# Test BM25 Index on all datasets and save results

In [None]:
import json
from pyserini.search.lucene import LuceneSearcher
from tqdm.auto import tqdm
import os
import numpy as np
import itertools


def test_bm25(index_dir, test_data_path, k1=0.4, b=0.4, topk=10, predict_type='entities', tqdm_desc="batches", verbose=True):
    '''Search for candidates using a bm25 index and compute retrieval metrics.

    Args:
        index_dir: index location handed to ``LuceneSearcher``.
        test_data_path: JSON file whose top-level ``'dataset'`` key holds the
            list of samples.
        k1, b: BM25 parameters forwarded to ``searcher.set_bm25``.
        topk: number of hits requested for every question.
        predict_type: sample key that holds the gold annotations; the gold
            ids are the keys of ``sample[predict_type]['query']``.
        tqdm_desc: label of the progress bar.
        verbose: when false the progress bar is disabled.

    Returns:
        ``(predicted_candidates, precision, recall, f1)`` where
        ``predicted_candidates`` maps a sample id to the list of retrieved
        doc ids and the three metrics are averaged over the samples that
        have gold annotations. All three metrics are ``0.0`` when no sample
        has gold annotations.
    '''

    searcher = LuceneSearcher(index_dir)
    searcher.set_bm25(k1, b)

    # Context manager so the file handle is released deterministically.
    with open(test_data_path, encoding='utf-8') as f:
        samples = json.load(f)['dataset']

    overall_precision, overall_recall, overall_f1 = 0.0, 0.0, 0.0
    samples_with_sparql = 0
    predicted_candidates = dict()

    # `disable` is understood by every tqdm front-end; `display` is only
    # accepted by the notebook widget and fails with the console bar.
    for sample in tqdm(samples, total=len(samples), desc=tqdm_desc, disable=not verbose):
        query = sample["en_question"]
        query_id = sample['id']

        # Samples without gold annotations of the requested type are not evaluated.
        if not (predict_type in sample and sample[predict_type]['query']):
            continue
        gold_ids = set(sample[predict_type]['query'])

        samples_with_sparql += 1

        predicted_ids = [hit.docid for hit in searcher.search(query, k=topk)]
        predicted_candidates[query_id] = predicted_ids

        true_positives = set(predicted_ids) & gold_ids

        # Precision: proportion of retrieved candidates that are gold
        precision = len(true_positives) / len(predicted_ids) if predicted_ids else 0.0

        # Recall: proportion of gold ids that are retrieved
        recall = len(true_positives) / len(gold_ids) if gold_ids else 0.0

        # F1-Score: harmonic mean of precision and recall
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        overall_precision += precision
        overall_recall += recall
        overall_f1 += f1

    # Nothing was evaluated: report zeros instead of dividing by zero.
    if samples_with_sparql == 0:
        return predicted_candidates, 0.0, 0.0, 0.0

    overall_precision /= samples_with_sparql
    overall_recall /= samples_with_sparql
    overall_f1 /= samples_with_sparql

    return predicted_candidates, overall_precision, overall_recall, overall_f1


def save_predicates_retriever_result(path, predicted_candidates):
    '''Write predicate candidates, enriched with label and description, to `path`.

    `predicted_candidates` maps a sample id to an iterable of predicate ids.
    Label and description of each predicate are looked up in
    'data/wikidata_relations_info.json'; the result is dumped as JSON.
    '''

    with open('data/wikidata_relations_info.json', 'r', encoding='utf-8') as source:
        relation_info = json.load(source)

    enriched = {
        sample_id: {
            pid: {
                'label': relation_info[pid]['label'],
                'description': relation_info[pid]['description'],
            }
            for pid in candidate_ids
        }
        for sample_id, candidate_ids in predicted_candidates.items()
    }

    with open(path, 'w', encoding='utf-8') as target:
        json.dump(enriched, target, ensure_ascii=False, indent=4)

    print('OK')


def save_entities_retriever_result(path, predicted_candidates, candidates_labels):
    '''Write entity candidates together with their stored label info to `path`.

    `predicted_candidates` maps a sample id to an iterable of entity ids;
    every id is replaced by its entry from `candidates_labels` and the
    resulting mapping is dumped as JSON.
    '''
    enriched = {
        sample_id: {qid: candidates_labels[qid] for qid in candidate_ids}
        for sample_id, candidate_ids in predicted_candidates.items()
    }

    with open(path, 'w', encoding='utf-8') as target:
        json.dump(enriched, target, ensure_ascii=False, indent=4)

    print('OK')


def bm25_params_grid_search(index_dir, test_data_path, topk=10, predict_type='entities'):
    '''Search for the bm25 hyperparameters that maximise recall.

    Every (k1, b) pair of a fixed grid is evaluated with `test_bm25`; each
    time a pair strictly improves the best recall seen so far it is printed.

    Args:
        index_dir: index location forwarded to `test_bm25`.
        test_data_path: dataset file forwarded to `test_bm25`.
        topk: number of hits requested per question.
        predict_type: annotation type forwarded to `test_bm25`.

    Returns:
        `(best_k1, best_b)`; `(0, 0)` when no pair reaches a recall above 0.
    '''

    k1_grid = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
               1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5,
               6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
    b_grid = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    grid = list(itertools.product(k1_grid, b_grid))

    best_recall = 0
    best_k1 = 0
    best_b = 0

    for k1, b in tqdm(grid):
        # Only recall drives the selection; the other outputs are not needed here.
        _, _, overall_recall, _ = test_bm25(index_dir, test_data_path, k1=k1, b=b, topk=topk, predict_type=predict_type, verbose=False)
        if overall_recall > best_recall:
            best_recall = overall_recall
            best_k1 = k1
            best_b = b
            print(f'best_recall: {best_recall}, k1: {k1}, b: {b}')

    return best_k1, best_b


def get_candidates_labels(ids_path='ids_to_keep.txt',
                          labels_path='data/wikidata_dump/processed_dump/labels',
                          descriptions_path='data/wikidata_dump/processed_dump/descriptions'):
    '''Get labels and descriptions for candidates from wikidata files.

    Args:
        ids_path: text file with one id per line; only these ids are kept.
        labels_path: directory of files with one JSON object per line, each
            having the keys 'qid' and 'label'.
        descriptions_path: directory of files with one JSON object per line,
            each having the keys 'qid' and 'description'.

    Returns:
        Dict mapping a qid to ``{'label': ...}``, plus a ``'description'``
        key for those ids that also occur in the description files.
    '''

    with open(ids_path, 'r', encoding='utf-8') as f:
        ids_to_keep = set(f.read().splitlines())

    candidates_labels = {}

    for filename in tqdm(os.listdir(labels_path), desc='reading labels'):
        with open(os.path.join(labels_path, filename), 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                qid = data['qid']
                if qid in ids_to_keep:
                    candidates_labels[qid] = {'label': data['label']}

    for filename in tqdm(os.listdir(descriptions_path), desc='reading descriptions'):
        with open(os.path.join(descriptions_path, filename), 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                qid = data['qid']
                # A description is attached only to ids that already have a label.
                if qid in candidates_labels:
                    candidates_labels[qid]['description'] = data['description']

    return candidates_labels

In [None]:
# Index directories passed to the BM25 searcher.
ENTITIES_INDEX_DIR = "combined_data_index"
PREDICATES_INDEX_DIR = "predicates_index"
# Alias so that code referring to the relations index by either name
# resolves to the same directory (RELATIONS_INDEX_DIR was otherwise undefined).
RELATIONS_INDEX_DIR = PREDICATES_INDEX_DIR

# Preprocessed test splits of the evaluated datasets.
LCQUAD_PATH = 'data/preprocessed/lcquad_2.0/lcquad_2.0_test.json'
QALD_PATH = 'data/preprocessed/qald/qald_test.json'
PAT_PATH = 'data/preprocessed/pat/pat_test.json'
RUBQ_PATH = 'data/preprocessed/rubq/rubq_test.json'

# Directory where the retrieved candidates are written.
RETRIEVAL_RESULT_SAVE_DIR = 'retrieval/retriever_result'

In [None]:
# Load label/description info for the candidate entities from the wikidata
# files; the result is reused by every entities cell that saves candidates.

candidates_labels = get_candidates_labels()

## LCQUAD

### Entities

In [None]:
# Grid search over the BM25 parameters for LCQUAD entities.

best_k1, best_b = bm25_params_grid_search(
    ENTITIES_INDEX_DIR,
    LCQUAD_PATH,
    topk=10,
    predict_type='entities',
)

In [None]:
# Retrieve LCQUAD entity candidates and report the averaged metrics.

(
    lcquad_entities_predicted_candidates,
    overall_precision,
    overall_recall,
    overall_f1,
) = test_bm25(
    ENTITIES_INDEX_DIR,
    LCQUAD_PATH,
    k1=2.947,
    b=0.2,
    predict_type='entities',
    topk=100,
)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels.

save_entities_retriever_result(
    f'{RETRIEVAL_RESULT_SAVE_DIR}/lcquad_test_entities_retrieval.json',
    lcquad_entities_predicted_candidates,
    candidates_labels,
)

### Relations

In [None]:
# Grid search over the BM25 parameters for LCQUAD relations.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
best_k1, best_b = bm25_params_grid_search(PREDICATES_INDEX_DIR, LCQUAD_PATH, topk=10, predict_type='relations')

In [None]:
# Retrieve LCQUAD relation candidates and report the averaged metrics.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
lcquad_relations_predicted_candidates, overall_precision, overall_recall, overall_f1 = test_bm25(PREDICATES_INDEX_DIR, LCQUAD_PATH, k1=5.18, b=0.01, predict_type='relations', topk=100)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels and descriptions.
save_predicates_retriever_result(f'{RETRIEVAL_RESULT_SAVE_DIR}/lcquad_test_predicates_retrieval.json', lcquad_relations_predicted_candidates)

## QALD

### Entities

In [None]:
# Grid search over the BM25 parameters for QALD entities.
best_k1, best_b = bm25_params_grid_search(
    ENTITIES_INDEX_DIR,
    QALD_PATH,
    topk=10,
    predict_type='entities',
)

In [None]:
# Retrieve QALD entity candidates and report the averaged metrics.
qald_entities_predicted_candidates, overall_precision, overall_recall, overall_f1 = test_bm25(ENTITIES_INDEX_DIR, QALD_PATH, k1=2.95, b=0.2, predict_type='entities', topk=100)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates like the other entities cells do (this call was missing for QALD).
save_entities_retriever_result(f'{RETRIEVAL_RESULT_SAVE_DIR}/qald_test_entities_retrieval.json', qald_entities_predicted_candidates, candidates_labels)

### Relations

In [None]:
# Grid search over the BM25 parameters for QALD relations.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
best_k1, best_b = bm25_params_grid_search(PREDICATES_INDEX_DIR, QALD_PATH, topk=10, predict_type='relations')

In [None]:
# Retrieve QALD relation candidates and report the averaged metrics.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
qald_relations_predicted_candidates, overall_precision, overall_recall, overall_f1 = test_bm25(PREDICATES_INDEX_DIR, QALD_PATH, k1=5.18, b=0.01, predict_type='relations', topk=100)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels and descriptions.
save_predicates_retriever_result(f'{RETRIEVAL_RESULT_SAVE_DIR}/qald_test_predicates_retrieval.json', qald_relations_predicted_candidates)

## PAT

### Entities

In [None]:
# Grid search over the BM25 parameters for PAT entities.
best_k1, best_b = bm25_params_grid_search(
    ENTITIES_INDEX_DIR,
    PAT_PATH,
    topk=10,
    predict_type='entities',
)

In [None]:
# Retrieve PAT entity candidates and report the averaged metrics.
(
    pat_entities_predicted_candidates,
    overall_precision,
    overall_recall,
    overall_f1,
) = test_bm25(
    ENTITIES_INDEX_DIR,
    PAT_PATH,
    k1=1.0,
    b=0.7,
    predict_type='entities',
    topk=100,
)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels.
save_entities_retriever_result(
    f'{RETRIEVAL_RESULT_SAVE_DIR}/pat_test_entities_retrieval.json',
    pat_entities_predicted_candidates,
    candidates_labels,
)

### Relations

In [None]:
# Grid search over the BM25 parameters for PAT relations.
best_k1, best_b = bm25_params_grid_search(
    PREDICATES_INDEX_DIR,
    PAT_PATH,
    topk=10,
    predict_type='relations',
)

In [None]:
# Retrieve PAT relation candidates and report the averaged metrics.
(
    pat_relations_predicted_candidates,
    overall_precision,
    overall_recall,
    overall_f1,
) = test_bm25(
    PREDICATES_INDEX_DIR,
    PAT_PATH,
    k1=0.1,
    b=0.01,
    predict_type='relations',
    topk=100,
)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels and descriptions.
save_predicates_retriever_result(
    f'{RETRIEVAL_RESULT_SAVE_DIR}/pat_test_predicates_retrieval.json',
    pat_relations_predicted_candidates,
)

## RUBQ

### Entities

In [None]:
# Grid search over the BM25 parameters for RUBQ entities.
best_k1, best_b = bm25_params_grid_search(
    ENTITIES_INDEX_DIR,
    RUBQ_PATH,
    topk=10,
    predict_type='entities',
)

In [None]:
# Retrieve RUBQ entity candidates and report the averaged metrics.
(
    rubq_entities_predicted_candidates,
    overall_precision,
    overall_recall,
    overall_f1,
) = test_bm25(
    ENTITIES_INDEX_DIR,
    RUBQ_PATH,
    k1=1.39,
    b=0.4,
    predict_type='entities',
    topk=100,
)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels.
save_entities_retriever_result(
    f'{RETRIEVAL_RESULT_SAVE_DIR}/rubq_test_entities_retrieval.json',
    rubq_entities_predicted_candidates,
    candidates_labels,
)

### Relations

In [None]:
# Grid search over the BM25 parameters for RUBQ relations.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
best_k1, best_b = bm25_params_grid_search(PREDICATES_INDEX_DIR, RUBQ_PATH, topk=10, predict_type='relations')

In [None]:
# Retrieve RUBQ relation candidates and report the averaged metrics.
# Uses PREDICATES_INDEX_DIR: RELATIONS_INDEX_DIR was never defined.
rubq_relations_predicted_candidates, overall_precision, overall_recall, overall_f1 = test_bm25(PREDICATES_INDEX_DIR, RUBQ_PATH, k1=5.18, b=0.01, predict_type='relations', topk=100)

print('Precision: ', overall_precision)
print('Recall: ', overall_recall)
print('F1: ', overall_f1)

# Persist the candidates together with their labels and descriptions.
save_predicates_retriever_result(f'{RETRIEVAL_RESULT_SAVE_DIR}/rubq_test_predicates_retrieval.json', rubq_relations_predicted_candidates)