In [28]:
import elasticsearch.helpers
import json
from prettytable import PrettyTable

def success_at_1(relevant, retrieved):
    """Return 1 if the top-ranked retrieved document is relevant, else 0.

    An empty ranking counts as a failure and yields 0.
    """
    top_is_relevant = len(retrieved) > 0 and retrieved[0] in relevant
    return 1 if top_is_relevant else 0
    
def success_at_5(relevant, retrieved):
    """Return 1 if any of the first five retrieved documents is relevant, else 0."""
    top_five = retrieved[:5]
    return 1 if any(doc in relevant for doc in top_five) else 0

def success_at_10(relevant, retrieved):
    """Return 1 if any of the first ten retrieved documents is relevant, else 0."""
    top_ten = retrieved[:10]
    return 1 if any(doc in relevant for doc in top_ten) else 0
    
    
def precision(relevant, retrieved):
    """Return the fraction of retrieved documents that are relevant.

    An empty ranking has precision 0.0. Every entry of ``retrieved`` is
    counted, so a document listed twice contributes twice.
    """
    if len(retrieved) == 0:
        return 0.0
    hits = sum(1 for doc in retrieved if doc in relevant)
    return hits / len(retrieved)
    
def recall(relevant, retrieved):
    """Return the fraction of relevant documents that were retrieved.

    Args:
        relevant: collection of relevant document ids for the query.
        retrieved: ranked list of retrieved document ids.

    Returns:
        hits / len(relevant), or 0.0 when there are no relevant documents.
        The qrels reader creates an empty set for queries without any
        positive judgement, so that case must not divide by zero.
    """
    if len(relevant) == 0:
        return 0.0
    hits = sum(1 for doc in retrieved if doc in relevant)
    return hits / len(relevant)
    
def f_measure(relevant, retrieved):
    """Return the harmonic mean of precision and recall (balanced F-measure).

    Yields 0.0 when both precision and recall are zero.
    """
    p = precision(relevant, retrieved)
    r = recall(relevant, retrieved)
    denominator = p + r
    if denominator == 0.0:
        return 0.0
    return 2 * p * r / denominator
    
def precision_at_k(relevant, retrieved, k):
    """Return precision over the first ``k`` retrieved documents.

    The denominator is the number of documents actually present in the
    top-k slice, so a ranking shorter than ``k`` is not penalised.
    """
    top_k = retrieved[:k]
    return precision(relevant, top_k)

def r_precision(relevant, retrieved):
    """Return precision at rank R, where R is the number of relevant documents."""
    cutoff = len(relevant)
    return precision(relevant, retrieved[:cutoff])

def interpolated_precision_at_recall_X(relevant, retrieved, X):
    """Return the interpolated precision at recall level ``X``.

    This is the highest precision observed at any rank whose recall is at
    least ``X``.

    Args:
        relevant: collection of relevant document ids for the query.
        retrieved: ranked list of retrieved document ids.
        X: recall level in [0.0, 1.0].

    Returns:
        The maximum precision over all ranks 1..len(retrieved) with
        recall >= X, or 0.0 if no rank reaches that recall (including
        when ``relevant`` or ``retrieved`` is empty).
    """
    if len(relevant) == 0:
        # Recall is undefined without relevant documents; avoid dividing by zero.
        return 0.0

    max_precision = 0.0
    hits = 0
    # Ranks are 1-based and must include the last one: the previous
    # range(len(retrieved)) stopped one short and never scored the full ranking.
    for rank, doc in enumerate(retrieved, start=1):
        if doc in relevant:
            hits += 1
        if hits / len(relevant) >= X:
            max_precision = max(max_precision, hits / rank)
    return max_precision

def average_precision(relevant, retrieved):
    """Return the average precision of a ranking.

    Precision is taken at the rank of every retrieved relevant document,
    summed, and divided by the total number of relevant documents.

    Args:
        relevant: collection of relevant document ids for the query.
        retrieved: ranked list of retrieved document ids.

    Returns:
        The average precision, or 0.0 when ``relevant`` is empty (the
        qrels reader produces empty sets for queries without positive
        judgements, which previously raised ZeroDivisionError).
    """
    if len(relevant) == 0:
        return 0.0

    sum_avg = 0.0
    hits = 0
    # A running hit counter gives precision-at-rank directly, instead of
    # recounting the whole prefix for every relevant document.
    for rank, doc in enumerate(retrieved, start=1):
        if doc in relevant:
            hits += 1
            sum_avg += hits / rank
    return sum_avg / len(relevant)
    
def read_qrels_file(qrels_file):
    """Read a TREC qrels file into a mapping of query id to relevant doc ids.

    Each non-blank line must hold four whitespace-separated fields:
    ``qid q0 doc_id rel``. Only rows whose ``rel`` field is exactly "1" add
    a document; every query id seen still gets an entry (possibly an empty
    set).

    Args:
        qrels_file: path of the qrels file.

    Returns:
        dict mapping query_id -> set([docid1, docid2, ...]).

    Raises:
        ValueError: if a non-blank line does not have exactly four fields.
    """
    trec_relevant = dict()  # query_id -> set([docid1, docid2, ...])
    with open(qrels_file, 'r') as qrels:
        for line in qrels:
            fields = line.split()
            if not fields:
                # Blank (e.g. trailing) lines carry no judgement; skip them
                # instead of failing the tuple unpack below.
                continue
            (qid, _q0, doc_id, rel) = fields
            judged = trec_relevant.setdefault(qid, set())
            if rel == "1":
                judged.add(doc_id)
    return trec_relevant

def read_run_file(run_file):
    """Read a TREC run file into a mapping of query id to ranked doc ids.

    Each non-blank line must hold six whitespace-separated fields:
    ``qid q0 doc_id rank score tag``. Documents are kept in file order; the
    rank and score columns are not used for sorting.

    Args:
        run_file: path of the run file produced by the IR system.

    Returns:
        dict mapping query_id -> [docid1, docid2, ...].

    Raises:
        ValueError: if a non-blank line does not have exactly six fields.
    """
    trec_retrieved = dict()  # query_id -> [docid1, docid2, ...]
    with open(run_file, 'r') as run:
        for line in run:
            fields = line.split()
            if not fields:
                # Tolerate blank lines rather than failing the unpack below.
                continue
            (qid, _q0, doc_id, _rank, _score, _tag) = fields
            trec_retrieved.setdefault(qid, []).append(doc_id)
    return trec_retrieved
    
def read_eval_files(qrels_file, run_file):
    """Load the qrels and the run file and return them as a pair.

    Returns:
        (relevant_by_query, retrieved_by_query), in that order.
    """
    relevant_by_query = read_qrels_file(qrels_file)
    retrieved_by_query = read_run_file(run_file)
    return relevant_by_query, retrieved_by_query

def mean_average_precision(all_relevant, all_retrieved):
    """Return the mean of average_precision over the queries in the run.

    The average is taken over the query ids of ``all_retrieved`` only, so a
    query without any retrieved document is not counted (mean_metric, by
    contrast, iterates over ``all_relevant``). Every run query id must be
    present in ``all_relevant``.
    """
    per_query = [
        average_precision(all_relevant[qid], ranking)
        for qid, ranking in all_retrieved.items()
    ]
    return sum(per_query) / len(per_query)

def mean_metric(measure, all_relevant, all_retrieved):
    """Average ``measure`` over every query in ``all_relevant``.

    A query that has no entry in ``all_retrieved`` is scored against an
    empty ranking.

    Returns:
        A ``(label, mean)`` tuple where label is "mean " followed by the
        measure function's ``__name__``.
    """
    running_sum = 0
    for qid, judged in all_relevant.items():
        ranking = all_retrieved.get(qid, [])
        running_sum += measure(judged, ranking)
    label = "mean " + measure.__name__
    return label, running_sum / len(all_relevant)

def trec_eval(qrels_file, run_file):
    """Evaluate a run file against a qrels file with a fixed metric suite.

    The suite is: success@1/5/10, R-precision, precision@1/5/10/50/100,
    interpolated precision at recall 0.0 to 1.0 in steps of 0.1, average
    precision, F-measure and recall.

    Returns:
        A list of ``(label, mean_score)`` tuples as produced by mean_metric,
        in the order listed above.

    Raises:
        ValueError: if the run contains query ids absent from the qrels.
    """
    def named_metric(name, base_metric, **fixed):
        # mean_metric labels each score with measure.__name__, so every
        # wrapper must carry its exact metric name.
        def bound(rel, ret):
            return base_metric(rel, ret, **fixed)
        bound.__name__ = name
        return bound

    precision_cutoffs = [
        named_metric("precision_at_{}".format(k), precision_at_k, k=k)
        for k in (1, 5, 10, 50, 100)
    ]
    # Tenths 0..10 give the labels precision_at_recall_00 .. precision_at_recall_10.
    recall_levels = [
        named_metric("precision_at_recall_{:02d}".format(tenth),
                     interpolated_precision_at_recall_X, X=tenth / 10)
        for tenth in range(11)
    ]

    (all_relevant, all_retrieved) = read_eval_files(qrels_file, run_file)

    unknown_qids = set(all_retrieved.keys()) - set(all_relevant.keys())
    if unknown_qids:
        raise ValueError("Unknown qids in run: {}".format(sorted(unknown_qids)))

    metrics = [
        success_at_1,
        success_at_5,
        success_at_10,
        r_precision,
        *precision_cutoffs,
        *recall_levels,
        average_precision,
        f_measure,
        recall,
    ]
    return [mean_metric(metric, all_relevant, all_retrieved) for metric in metrics]

def print_trec_eval(qrels_file, run_file):
    """Evaluate ``run_file`` with trec_eval and print one line per metric."""
    scores = trec_eval(qrels_file, run_file)
    print("Results for {}".format(run_file))
    for name, value in scores:
        row = "{:<30} {:.4}".format(name, value)
        print(row)
        
def sign_test_values(measure, qrels_file, run_file_1, run_file_2):
    """Count per-query wins and losses of run 1 against run 2.

    Args:
        measure: function ``(relevant, retrieved) -> score``.
        qrels_file: path of the qrels file.
        run_file_1: path of the first run file.
        run_file_2: path of the second run file.

    Returns:
        ``(better, worse)``: the number of qrels queries on which run 1
        scores strictly higher, respectively strictly lower, than run 2.
        Ties are counted in neither.
    """
    all_relevant = read_qrels_file(qrels_file)
    all_retrieved_1 = read_run_file(run_file_1)
    all_retrieved_2 = read_run_file(run_file_2)
    better = 0
    worse = 0

    for key, relevant in all_relevant.items():
        # A run file has no lines for a query that returned no hits; score
        # it against an empty ranking (as mean_metric does) instead of
        # raising KeyError.
        measure_file_1 = measure(relevant, all_retrieved_1.get(key, []))
        measure_file_2 = measure(relevant, all_retrieved_2.get(key, []))
        if measure_file_1 > measure_file_2:
            better += 1
        elif measure_file_1 < measure_file_2:
            worse += 1

    return better, worse
    
def precision_at_rank_5(rel, ret):
    """Return precision over the first five retrieved documents."""
    cutoff = 5
    return precision_at_k(rel, ret, k=cutoff)


def read_documents(file_name):
    """
    Returns a generator of documents to be indexed by elastic, read from file_name

    The file holds one JSON object per line: an action line with an
    'index' key supplies the '_id' for the document line (an object with a
    'PMID' key) that follows it. Each yielded document is the parsed
    document line with that '_id' added. Blank lines are skipped.

    Raises:
        ValueError: if a line is neither an action nor a document line, or
            if a document line appears before any action line.
    """
    id_param = None
    seen_action = False
    with open(file_name, 'r') as documents:
        for line in documents:
            if not line.strip():
                # json.loads would fail on an empty line.
                continue
            doc_line = json.loads(line)
            if 'index' in doc_line:
                id_param = doc_line['index']['_id']
                seen_action = True
            elif 'PMID' in doc_line:
                if not seen_action:
                    # Previously this surfaced as UnboundLocalError.
                    raise ValueError('Woops, error in index file: document line before any index action line')
                # NOTE(review): the id of the most recent action line is
                # reused if several document lines follow it - confirm the
                # input always alternates action/document.
                doc_line['_id'] = id_param
                yield doc_line
            else:
                raise ValueError('Woops, error in index file')

def create_index(es_param, index_name, body=None):
    """(Re)create ``index_name`` on the given Elasticsearch client.

    Any existing index of that name is deleted first, then the index is
    created with ``body`` (an empty dict when ``body`` is None).
    """
    index_body = {} if body is None else body
    # delete index when it already exists (400/404 responses are ignored)
    es_param.indices.delete(index=index_name, ignore=[400, 404])
    # create the index
    es_param.indices.create(index=index_name, body=index_body)
                
def index_documents(es_param, collection_file_name, index_name, body=None):
    """Recreate ``index_name`` and bulk-index the documents of a collection file.

    The index is (re)created via create_index with ``body`` (an empty dict
    when None); the documents come from read_documents.

    Returns:
        Whatever ``elasticsearch.helpers.bulk`` returns.
    """
    index_body = {} if body is None else body
    create_index(es_param, index_name, index_body)
    # bulk index the documents from file_name
    actions = read_documents(collection_file_name)
    return elasticsearch.helpers.bulk(
        es_param,
        actions,
        index=index_name,
        chunk_size=2000,
        request_timeout=30,
    )

def make_trec_run(es_param, topics_file_name, run_file_name, index_name="genomics", run_name="test"):
    """Run every topic against an index and write the hits as a TREC run file.

    Args:
        es_param: Elasticsearch client; only its ``search`` method is used.
        topics_file_name: path of the topics file, one ``qid<TAB>query`` per
            line. Blank lines are skipped.
        run_file_name: file name of the run, written below the ``run/``
            directory (created if missing).
        index_name: index to search.
        run_name: tag written in the last column of every run line.

    Raises:
        ValueError: if a non-blank topic line contains no tab separator.
    """
    import os  # local import: the file's import block does not provide os

    run_path = f"run/{run_file_name}"
    os.makedirs(os.path.dirname(run_path), exist_ok=True)

    # Open the topics first so a missing topics file does not truncate an
    # existing run file.
    with open(topics_file_name, 'r') as test_queries, open(run_path, 'w') as run_file:
        for line in test_queries:
            line = line.strip()
            if not line:
                continue
            # Split on the first tab only, so a query containing tabs is
            # kept whole instead of breaking the unpack.
            (qid, query) = line.split('\t', 1)

            search_query = {
                "query": {
                    "bool": {
                        "should": [
                            {"match": {"TI": query}},  # Search in the title field.
                            {"match": {"AB": query}}   # Search in the abstract field.
                        ]
                    }
                },
                "size": 1000  # You can adjust the number of results as needed.
            }

            # Execute the query against the specified Elasticsearch index.
            search_results = es_param.search(index=index_name, body=search_query)

            # Write the hits in TREC format: qid Q0 docid rank score tag (ranks are 1-based).
            for rank, hit in enumerate(search_results['hits']['hits'], start=1):
                pmid = hit['_source']['PMID']
                score = hit['_score']
                run_file.write(f"{qid} Q0 {pmid} {rank} {score} {run_name}\n")

def evaluate_models(models_param, qrels_file):
    """Evaluate the run file of every model against ``qrels_file``.

    The run of model ``name`` is read from ``run/<name>.run``.

    Returns:
        dict mapping model name -> {metric label: score}.
    """
    return {
        model_name: dict(trec_eval(qrels_file, f"run/{model_name}.run"))
        for model_name in models_param
    }

def determine_best_model(results_param, metric):
    """Return the model name with the highest score for ``metric``.

    On a tie the first such model in iteration order wins.
    """
    def score_of(model):
        return results_param[model][metric]

    return max(results_param, key=score_of)

def print_results_table(results_param):
    """Print a metric-by-model table and name the best model by MAP.

    ``results_param`` maps model name -> {metric label: score}. The rows are
    the metric labels of the first model; every model must provide them,
    including 'mean average_precision'.
    """
    # One column per model, in the mapping's order
    columns = list(results_param.keys())

    # Row labels come from the first model's results
    first_scores = next(iter(results_param.values()))
    row_labels = list(first_scores.keys())

    table = PrettyTable()
    table.field_names = ["Metric", *columns]

    # One row per metric, scores formatted to four decimals
    for metric_name in row_labels:
        cells = [f"{results_param[column][metric_name]:.4f}" for column in columns]
        table.add_row([metric_name, *cells])

    print(table)

    # Determine and print the best model based on mean average precision (MAP)
    winner = determine_best_model(results_param, 'mean average_precision')
    print(f"\n\033[1;34mThe best model based on mean average precision (MAP) is: {winner}\033[0m")  # Emphasize with blue color
    
def generate_dfr_tokenization(strategy):
    """Build an index body combining a DFR similarity with a custom tokenizer.

    ``strategy`` is used as the definition of the tokenizer "my_tokenizer",
    which the analyzer "my_analyzer" wraps. The text fields "TI" and "AB"
    both use that analyzer and the "custom_dfr_settings" similarity.
    """
    dfr_similarity = {
        "type": "DFR",
        "basic_model": "g",
        "after_effect": "l",
        "normalization": "h2",
        "normalization.h2.c": "3.0"
    }
    analysis = {
        "tokenizer": {"my_tokenizer": strategy},
        "analyzer": {
            "my_analyzer": {"type": "custom", "tokenizer": "my_tokenizer"}
        }
    }
    # Each field gets its own mapping dict; they are not shared objects.
    properties = {
        field: {
            "type": "text",
            "similarity": "custom_dfr_settings",
            "analyzer": "my_analyzer"
        }
        for field in ("TI", "AB")
    }
    return {
        "settings": {
            "index": {"similarity": {"custom_dfr_settings": dfr_similarity}},
            "analysis": analysis
        },
        "mappings": {"properties": properties}
    }

# Module-level client used by dfr_index and dfr_run; targets http://localhost:9200.
es = elasticsearch.Elasticsearch('http://localhost:9200')

def dfr_index(body_dfr):
    """Index the medline collection with the tokenizer described by ``body_dfr``.

    The index is named ``genomics-<type>``, where ``<type>`` is
    ``body_dfr["type"]``; ``body_dfr`` itself becomes the tokenizer
    definition of the index body.

    Returns:
        The tokenizer type string.
    """
    tokenizer_type = body_dfr["type"]
    index_body = generate_dfr_tokenization(body_dfr)
    target_index = f"genomics-{tokenizer_type}"
    index_documents(es, 'data01/FIR-s05-medline.json', target_index, body=index_body)
    return tokenizer_type

def dfr_run(body_type_file):
    """Run the training queries on ``genomics-<body_type_file>``.

    The results are written by make_trec_run as ``dfr_<body_type_file>.run``.
    """
    run_file_name = f"dfr_{body_type_file}.run"
    source_index = f"genomics-{body_type_file}"
    make_trec_run(es, 'data01/FIR-s05-training-queries-simple.txt', run_file_name, source_index)

In [None]:
# Build the "genomics-standard" index with the "standard" tokenizer type.
body_type = dfr_index({
    "type": "standard"
})

In [None]:
# Query genomics-<body_type> and write dfr_<body_type>.run; body_type is set by the preceding index cell.
dfr_run(body_type)

In [None]:
# Build the "genomics-whitespace" index with the "whitespace" tokenizer type.
body_type = dfr_index({
    "type": "whitespace"
})

In [None]:
# Query genomics-<body_type> and write dfr_<body_type>.run; body_type is set by the preceding index cell.
dfr_run(body_type)

In [None]:
# Build the "genomics-keyword" index with the "keyword" tokenizer type.
# NOTE(review): "token_chars" looks copied from the ngram cells; presumably the
# keyword tokenizer does not use it - confirm against the Elasticsearch docs.
body_type = dfr_index({
    "type": "keyword",
    "token_chars": [
        "letter",
        "digit"
    ]
})

In [None]:
# Query genomics-<body_type> and write dfr_<body_type>.run; body_type is set by the preceding index cell.
dfr_run(body_type)

In [None]:
# Build the "genomics-ngram" index: "ngram" tokenizer type with min_gram 4,
# max_gram 5 and token_chars "letter" and "digit".
body_type = dfr_index({
    "type": "ngram",
    "min_gram": 4,
    "max_gram": 5,
    "token_chars": [
        "letter",
        "digit"
    ]
})

In [None]:
# Query genomics-<body_type> and write dfr_<body_type>.run; body_type is set by the preceding index cell.
dfr_run(body_type)

In [None]:
# Build the "genomics-edge_ngram" index: "edge_ngram" tokenizer type with
# min_gram 2, max_gram 20 and token_chars "letter" and "digit".
body_type = dfr_index({
    "type": "edge_ngram",
    "min_gram": 2,
    "max_gram": 20,
    "token_chars": [
        "letter",
        "digit"
    ]
})

In [None]:
# Query genomics-<body_type> and write dfr_<body_type>.run; body_type is set by the preceding index cell.
dfr_run(body_type)

In [41]:
# Example usage:
# models = ["boolean", "dirichelet", "lmjelinekmercer", "baseline", "dfr_settings", "pattern_settings"]
# Each name must match a run/<name>.run file, since evaluate_models builds that path.
models = ["dfr_standard", "dfr_whitespace", "dfr_keyword", "dfr_edge_ngram", "dfr_ngram"]
# Score every run against the training qrels, then print one table column per model.
results = evaluate_models(models, 'data01/FIR-s05-training-qrels.txt')
print_results_table(results)

+-----------------------------+--------------+----------------+-------------+----------------+-----------+
|            Metric           | dfr_standard | dfr_whitespace | dfr_keyword | dfr_edge_ngram | dfr_ngram |
+-----------------------------+--------------+----------------+-------------+----------------+-----------+
|      mean success_at_1      |    0.1579    |     0.0526     |    0.0000   |     0.1053     |   0.1053  |
|      mean success_at_5      |    0.2895    |     0.1053     |    0.0000   |     0.2368     |   0.2368  |
|      mean success_at_10     |    0.4211    |     0.1316     |    0.0000   |     0.2895     |   0.2895  |
|       mean r_precision      |    0.1188    |     0.0566     |    0.0000   |     0.0870     |   0.0730  |
|     mean precision_at_1     |    0.1579    |     0.0526     |    0.0000   |     0.1053     |   0.1053  |
|     mean precision_at_5     |    0.0947    |     0.0368     |    0.0000   |     0.0632     |   0.0632  |
|     mean precision_at_10    |    0.