### Imports

In [17]:
import time
import numpy as np
import pandas as pd
from typing import Callable, Union, List, Dict, Any, Tuple
from tqdm.notebook import tqdm

In [7]:
from index_utils import IndexUtil
from experiment_utils import ExperimentUtil

### Prepare index/mappings/settings

In [8]:
# Index name handed to IndexUtil below; also the default `index` argument of test_b_k_impact.
INDEX_NAME = 'index_bm25_parameters'

In [9]:
# Shared IndexUtil handle; test_b_k_impact deletes, re-creates and re-populates the index through it.
INDEX = IndexUtil(INDEX_NAME)

In [10]:
def get_mappings(analyzer:str = 'stop-english_standard_analyzer'):
    """Build the field mapping used when the experiment index is created.

    The mapping declares two fields: ``article_id`` (type ``keyword``) and
    ``text`` (type ``text``) whose analyzer is the name given in ``analyzer``.

    Args:
        analyzer: Name of the analyzer to attach to the ``text`` field.

    Returns:
        A dict of the form ``{"properties": {...}}``.
    """
    article_id_field = {"type": "keyword"}
    text_field = {"type": "text", "analyzer": analyzer}
    return {"properties": {"article_id": article_id_field, "text": text_field}}

In [11]:
def get_settings(b=0.75, k1=1.2):
    """Build index settings with a single shard and a default BM25 similarity.

    Starts from ``IndexUtil.get_default_settings()``, asks
    ``IndexUtil.set_shards_in_settings`` for one shard, and registers a
    ``default`` similarity of type ``BM25`` with the given parameters.

    Args:
        b: BM25 ``b`` parameter written to the similarity definition.
        k1: BM25 ``k1`` parameter written to the similarity definition.

    Returns:
        The settings dict obtained from ``IndexUtil`` with
        ``settings["index"]["similarity"]["default"]`` set.
    """
    settings = IndexUtil.get_default_settings()
    IndexUtil.set_shards_in_settings(settings, shards=1)
    # Merge into the existing "index" section instead of replacing it, so any
    # keys already stored there survive. NOTE(review): whether
    # set_shards_in_settings writes under "index" is not visible from this
    # notebook; if it does, a wholesale assignment would discard the shard count.
    index_section = settings.setdefault("index", {})
    index_section.setdefault("similarity", {})["default"] = {
        "type": "BM25",
        "b": b,
        "k1": k1,
    }
    return settings

In [12]:
def document_mapping_func(doc: Dict[str, Any])->Dict[str, Any]:
    """Project a dataset document onto the two fields stored in the index.

    Args:
        doc: Source document; must contain the keys ``uuid`` and ``text``.

    Returns:
        A new dict with ``article_id`` (taken from ``doc['uuid']``) and
        ``text`` (taken from ``doc['text']``). Other keys are dropped.

    Raises:
        KeyError: If ``uuid`` or ``text`` is missing from ``doc``.
    """
    article_id, text = doc['uuid'], doc['text']
    return dict(article_id=article_id, text=text)

### Load datasets

In [14]:
# Each dataset is unpacked into a (documents, questions) pair; the pairs are the
# inputs expected by test_b_k_impact(documents, questions).
DOCUMENTS_SQUAD, QUESTIONS_SQUAD = ExperimentUtil.load_dataset('squad_10k')
DOCUMENTS_SWIFT, QUESTIONS_SWIFT = ExperimentUtil.load_dataset('swift_ui')

### Experiment

In [27]:
def test_b_k_impact(documents, questions, index = INDEX_NAME, query_fuc = INDEX.default_query,
                    b_range = None, k_range = None):
    """Grid-search the BM25 ``b`` / ``k1`` parameters and record hit rates.

    For every ``(b, k1)`` pair the global ``INDEX`` is deleted, re-created with
    ``get_mappings()`` and ``get_settings(b=b, k1=k1)``, re-populated with
    ``documents`` and then scored with ``ExperimentUtil.validate``.

    Args:
        documents: Documents passed to ``INDEX.index_all_docs`` together with
            ``document_mapping_func``.
        questions: Questions passed to ``ExperimentUtil.validate``.
        index: Index name forwarded to ``ExperimentUtil.validate``.
            NOTE(review): the index is always rebuilt through the global
            ``INDEX``; passing a different name here would validate against an
            index this function does not rebuild. Confirm this is intended.
        query_fuc: Query function forwarded to ``ExperimentUtil.validate``.
        b_range: Optional iterable of ``b`` values. Defaults to 0.00 .. 1.00
            in steps of 0.05 (21 values).
        k_range: Optional iterable of ``k1`` values. Defaults to 0.0 .. 3.0
            in steps of 0.1 (31 values).

    Returns:
        ``pd.DataFrame`` with columns ``b``, ``k1``, ``hits@10``, ``hits@5``,
        ``hits@3``, ``hits@1``; one row per pair, ``b`` varying slowest.
    """
    hit_columns = ['hits@10', 'hits@5', 'hits@3', 'hits@1']

    # Default grids are built with linspace and rounded once, so the value
    # written into the index settings is exactly the value reported in the
    # result table (float-step arange yields e.g. 0.15000000000000002).
    if b_range is None:
        b_range = np.around(np.linspace(0, 1, 21), decimals=2)
    if k_range is None:
        k_range = np.around(np.linspace(0, 3, 31), decimals=2)
    b_values = list(b_range)
    # Materialised because the inner grid is iterated once per outer value.
    k_values = list(k_range)

    records = []
    for b in tqdm(b_values, desc='b'):
        for k in tqdm(k_values, desc='k1', leave=False):
            # Similarity parameters are index settings here, so the index is
            # rebuilt from scratch for every combination.
            INDEX.delete_index()
            INDEX.create_index(get_mappings(), get_settings(b=b, k1=k))
            INDEX.index_all_docs(documents, document_mapping_func)
            # NOTE(review): presumably gives the freshly indexed documents time
            # to become searchable before validation — confirm; an explicit
            # refresh on the index would be more reliable than a fixed delay.
            time.sleep(1)
            all_hits = ExperimentUtil.validate(index, questions, query_fuc)
            row = {'b': b, 'k1': k}
            row.update({column: all_hits[column] for column in hit_columns})
            records.append(row)

    # Explicit columns keep the column order and schema stable even when a
    # grid is empty.
    return pd.DataFrame(records, columns=['b', 'k1'] + hit_columns)

In [30]:
# bm25_params_swift_df = test_b_k_impact(DOCUMENTS_SWIFT, QUESTIONS_SWIFT)

In [32]:
# bm25_params_squad_df = test_b_k_impact(DOCUMENTS_SQUAD, QUESTIONS_SQUAD)

In [None]:
# bm25_params_swift_df.to_csv('results/bm_25_params_swift.csv')

In [None]:
# bm25_params_squad_df.to_csv('results/bm_25_params_squad.csv')

In [37]:
# Load previously saved grid-search results instead of re-running the experiment.
# These CSVs are the files written by the commented-out `to_csv` cells above, so they
# must already exist on disk; index_col=[0] restores the saved row index.
bm25_params_swift_df = pd.read_csv('results/bm_25_params_swift.csv', index_col=[0])
bm25_params_squad_df = pd.read_csv('results/bm_25_params_squad.csv', index_col=[0])

### Explore results

#### SWIFT

In [46]:
# SWIFT: all (b, k1) combinations ordered by hits@10, best first.
bm25_params_swift_df.sort_values('hits@10', ascending=False)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
246,0.35,2.9,0.956522,0.875000,0.804348,0.581522
278,0.40,3.0,0.956522,0.875000,0.804348,0.586957
245,0.35,2.8,0.956522,0.880435,0.804348,0.576087
338,0.50,2.8,0.956522,0.869565,0.804348,0.586957
339,0.50,2.9,0.956522,0.869565,0.804348,0.592391
...,...,...,...,...,...,...
31,0.05,0.0,0.831522,0.722826,0.608696,0.418478
589,0.95,0.0,0.831522,0.722826,0.608696,0.418478
310,0.50,0.0,0.831522,0.722826,0.608696,0.418478
186,0.30,0.0,0.831522,0.722826,0.608696,0.418478


In [47]:
# SWIFT: all (b, k1) combinations ordered by hits@1, best first.
bm25_params_swift_df.sort_values('hits@1', ascending=False)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
580,0.90,2.2,0.945652,0.864130,0.788043,0.630435
579,0.90,2.1,0.945652,0.864130,0.788043,0.630435
641,1.00,2.1,0.945652,0.864130,0.782609,0.625000
642,1.00,2.2,0.945652,0.864130,0.782609,0.625000
610,0.95,2.1,0.945652,0.864130,0.782609,0.625000
...,...,...,...,...,...,...
589,0.95,0.0,0.831522,0.722826,0.608696,0.418478
372,0.60,0.0,0.831522,0.722826,0.608696,0.418478
31,0.05,0.0,0.831522,0.722826,0.608696,0.418478
341,0.55,0.0,0.831522,0.722826,0.608696,0.418478


In [53]:
# SWIFT: the 20 best (b, k1) combinations by hits@10.
bm25_params_swift_df.sort_values('hits@10', ascending=False).head(20)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
246,0.35,2.9,0.956522,0.875,0.804348,0.581522
278,0.4,3.0,0.956522,0.875,0.804348,0.586957
245,0.35,2.8,0.956522,0.880435,0.804348,0.576087
338,0.5,2.8,0.956522,0.869565,0.804348,0.586957
339,0.5,2.9,0.956522,0.869565,0.804348,0.592391
340,0.5,3.0,0.956522,0.869565,0.804348,0.586957
309,0.45,3.0,0.956522,0.869565,0.804348,0.586957
308,0.45,2.9,0.956522,0.869565,0.804348,0.586957
307,0.45,2.8,0.956522,0.869565,0.804348,0.586957
306,0.45,2.7,0.956522,0.869565,0.804348,0.586957


#### SQUAD

In [48]:
# SQUAD: all (b, k1) combinations ordered by hits@10, best first.
bm25_params_squad_df.sort_values('hits@10', ascending=False)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
477,0.75,1.2,0.928,0.896,0.871,0.766
415,0.65,1.2,0.928,0.892,0.866,0.765
446,0.70,1.2,0.927,0.894,0.870,0.768
413,0.65,1.0,0.927,0.894,0.867,0.766
417,0.65,1.4,0.927,0.892,0.865,0.761
...,...,...,...,...,...,...
93,0.15,0.0,0.891,0.853,0.816,0.689
496,0.80,0.0,0.891,0.856,0.818,0.691
279,0.45,0.0,0.891,0.854,0.813,0.694
341,0.55,0.0,0.891,0.855,0.812,0.688


In [49]:
# SQUAD: all (b, k1) combinations ordered by hits@1, best first.
bm25_params_squad_df.sort_values('hits@1', ascending=False)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
505,0.80,0.9,0.922,0.896,0.871,0.773
475,0.75,1.0,0.923,0.894,0.870,0.773
444,0.70,1.0,0.924,0.894,0.871,0.771
536,0.85,0.9,0.922,0.895,0.871,0.770
537,0.85,1.0,0.925,0.894,0.869,0.770
...,...,...,...,...,...,...
61,0.05,3.0,0.897,0.845,0.789,0.637
27,0.00,2.7,0.896,0.847,0.796,0.636
28,0.00,2.8,0.896,0.846,0.791,0.631
29,0.00,2.9,0.896,0.841,0.785,0.629


In [55]:
# SQUAD: the 20 best (b, k1) combinations by hits@1 (the SWIFT counterpart above ranks by hits@10).
bm25_params_squad_df.sort_values('hits@1', ascending=False).head(20)

Unnamed: 0,b,k1,hits@10,hits@5,hits@3,hits@1
505,0.8,0.9,0.922,0.896,0.871,0.773
475,0.75,1.0,0.923,0.894,0.87,0.773
444,0.7,1.0,0.924,0.894,0.871,0.771
536,0.85,0.9,0.922,0.895,0.871,0.77
537,0.85,1.0,0.925,0.894,0.869,0.77
474,0.75,0.9,0.923,0.894,0.87,0.77
443,0.7,0.9,0.925,0.896,0.869,0.769
506,0.8,1.0,0.924,0.892,0.87,0.769
535,0.85,0.8,0.921,0.898,0.87,0.769
442,0.7,0.8,0.924,0.894,0.867,0.769
