In [1]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from konlpy.tag import Mecab

In [2]:
# Load the training set; the SAMPLE_ID, Source and Question columns are pulled from it below.
train_df = pd.read_csv('./data/train.csv')

In [3]:
# Extract the per-question id, gold source label and question text as plain lists.
q_ids, q_answers, queries = (
    train_df[column].tolist() for column in ('SAMPLE_ID', 'Source', 'Question')
)
# Sanity check: the three lists should have the same length.
print(len(q_ids), len(q_answers), len(queries))

496 496 496


In [4]:
# Load the document table; its doc_id and doc columns form the corpus that gets tokenized and indexed below.
doc_df = pd.read_csv('./data/processed_train.csv')

In [5]:
# Document ids and document texts as parallel lists (same positional order).
doc_ids, docs = (doc_df[column].tolist() for column in ('doc_id', 'doc'))
print(len(doc_ids), len(docs))

1092 1092


In [6]:
# Tokenize documents and queries with the same MeCab analyzer so that
# both sides of the BM25 match share one vocabulary.
mecab = Mecab()
tokenized_doc = list(map(mecab.morphs, tqdm(docs)))
tokenized_q = list(map(mecab.morphs, tqdm(queries)))

100%|██████████| 1092/1092 [00:01<00:00, 995.65it/s] 
100%|██████████| 496/496 [00:00<00:00, 15282.04it/s]


In [7]:
# Index the tokenized corpus with the library's default BM25 parameters.
# NOTE: the following cell rebinds `bm25` with explicit b / k1 values.
bm25 = BM25Okapi(tokenized_doc)

In [43]:
# Index built with b=1.0, k1=2.5 -- the setting reported as best by the
# grid-search cell further down in this notebook.
bm25 = BM25Okapi(tokenized_doc, b=1.0, k1=2.5)
top_k = 20

# Hit counters: number of queries whose gold source shows up within the
# first 1 / 5 / 10 / 20 retrieved documents.
recall1 = recall5 = recall10 = recall20 = 0

for q_idx, query_tokens in enumerate(tqdm(tokenized_q)):
    gold = q_answers[q_idx]
    # Indices of the top_k highest-scoring documents, best first.
    ranked = np.argsort(bm25.get_scores(query_tokens))[-top_k:][::-1]

    # 0-based position of the first retrieved document whose id contains
    # the gold source; None when nothing in the top_k matches.
    first_hit = next(
        (pos for pos, doc_idx in enumerate(ranked) if gold in doc_ids[doc_idx]),
        None,
    )
    if first_hit is None:
        continue

    if first_hit < 1:
        recall1 += 1
    if first_hit < 5:
        recall5 += 1
    if first_hit < 10:
        recall10 += 1
    if first_hit < 20:
        recall20 += 1

print(recall1)
print(recall5)
print(recall10)
print(recall20)
# Counts kept for comparison; they match the grid-search row for b=1.0, k1=2.0:
# 348
# 455
# 473
# 488

100%|██████████| 496/496 [00:02<00:00, 177.20it/s]

357
457
474
486





In [42]:

# Candidate values for the two BM25 parameters explored by the grid search.
b_values = [0.7, 0.8, 0.9, 1.0]
k1_values = [1.5, 1.75, 2.0, 2.25, 2.5]

# Running best: highest score seen so far and the (b, k1) pair that produced it.
best_score = 0
best_params = dict(b=0, k1=0)

def calculate_recall(bm25, top_k=20, tokenized_queries=None, answers=None,
                     candidate_ids=None, show_progress=True):
    """Count how many queries retrieve their gold document within the top 1/5/10/20.

    For every query the documents are ranked by ``bm25.get_scores`` and only
    the ``top_k`` best are inspected. The first inspected document whose id
    satisfies ``answer in candidate_ids[idx]`` (a substring test when both are
    strings) determines the hit rank for that query.

    Args:
        bm25: Object exposing ``get_scores(tokens)`` that returns one score
            per document, in the same order as ``candidate_ids``.
        top_k: Number of best-scoring documents inspected per query. Must be
            non-negative. Cut-offs larger than ``top_k`` can never see more
            than ``top_k`` documents, so e.g. ``recall20`` is effectively
            "hit within min(20, top_k)".
        tokenized_queries: Tokenized queries. Defaults to the notebook-level
            ``tokenized_q``.
        answers: Gold answer per query, aligned with ``tokenized_queries``.
            Defaults to the notebook-level ``q_answers``.
        candidate_ids: Document ids aligned with the scores. Defaults to the
            notebook-level ``doc_ids``.
        show_progress: Wrap the query loop in ``tqdm`` when true.

    Returns:
        dict with keys ``'recall1'``, ``'recall5'``, ``'recall10'`` and
        ``'recall20'``. Values are raw hit *counts* (not ratios); divide by
        the number of queries to obtain recall rates.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    # Fall back to the notebook globals so existing `calculate_recall(bm25)`
    # calls keep working unchanged.
    if tokenized_queries is None:
        tokenized_queries = tokenized_q
    if answers is None:
        answers = q_answers
    if candidate_ids is None:
        candidate_ids = doc_ids

    cutoffs = (1, 5, 10, 20)
    hits = {cutoff: 0 for cutoff in cutoffs}

    query_iter = tqdm(tokenized_queries) if show_progress else tokenized_queries
    for i, tq in enumerate(query_iter):
        answer = answers[i]
        doc_scores = bm25.get_scores(tq)
        # Reverse first, then truncate: same result as `[-top_k:][::-1]` for
        # top_k >= 1, but top_k == 0 correctly yields no candidates
        # (`[-0:]` would have selected the whole ranking).
        top_indices = np.argsort(doc_scores)[::-1][:top_k]

        for rank, idx in enumerate(top_indices):
            if answer in candidate_ids[idx]:
                for cutoff in cutoffs:
                    if rank < cutoff:
                        hits[cutoff] += 1
                # Only the best-ranked match counts for a query.
                break

    return {f'recall{cutoff}': hits[cutoff] for cutoff in cutoffs}


# Exhaustive sweep over every (b, k1) combination, b varying slowest.
# The top-1 hit count is the selection criterion; on ties the first
# combination encountered is kept.
param_grid = [(b, k1) for b in b_values for k1 in k1_values]

for b, k1 in param_grid:
    bm25 = BM25Okapi(tokenized_doc, b=b, k1=k1)
    recall_scores = calculate_recall(bm25)

    overall_score = recall_scores['recall1']
    print(f"b: {b}, k1: {k1}, Recall Scores: {recall_scores}, Overall Score: {overall_score:.4f}")

    # Nothing to record unless this combination strictly beats the best so far.
    if overall_score <= best_score:
        continue
    best_score = overall_score
    best_params.update(b=b, k1=k1)

print(f"Best Score: {best_score:.4f}")
print(f"Best Parameters: b = {best_params['b']}, k1 = {best_params['k1']}")


  0%|          | 0/496 [00:00<?, ?it/s]

100%|██████████| 496/496 [00:02<00:00, 177.96it/s]


b: 0.7, k1: 1.5, Recall Scores: {'recall1': 318, 'recall5': 442, 'recall10': 471, 'recall20': 484}, Overall Score: 318.0000


100%|██████████| 496/496 [00:02<00:00, 178.56it/s]


b: 0.7, k1: 1.75, Recall Scores: {'recall1': 326, 'recall5': 445, 'recall10': 472, 'recall20': 481}, Overall Score: 326.0000


100%|██████████| 496/496 [00:02<00:00, 179.17it/s]


b: 0.7, k1: 2.0, Recall Scores: {'recall1': 329, 'recall5': 448, 'recall10': 471, 'recall20': 483}, Overall Score: 329.0000


100%|██████████| 496/496 [00:02<00:00, 179.01it/s]


b: 0.7, k1: 2.25, Recall Scores: {'recall1': 333, 'recall5': 448, 'recall10': 471, 'recall20': 481}, Overall Score: 333.0000


100%|██████████| 496/496 [00:02<00:00, 176.52it/s]


b: 0.7, k1: 2.5, Recall Scores: {'recall1': 337, 'recall5': 453, 'recall10': 471, 'recall20': 481}, Overall Score: 337.0000


100%|██████████| 496/496 [00:02<00:00, 177.21it/s]


b: 0.8, k1: 1.5, Recall Scores: {'recall1': 324, 'recall5': 448, 'recall10': 474, 'recall20': 486}, Overall Score: 324.0000


100%|██████████| 496/496 [00:02<00:00, 177.20it/s]


b: 0.8, k1: 1.75, Recall Scores: {'recall1': 328, 'recall5': 449, 'recall10': 474, 'recall20': 484}, Overall Score: 328.0000


100%|██████████| 496/496 [00:02<00:00, 180.05it/s]


b: 0.8, k1: 2.0, Recall Scores: {'recall1': 336, 'recall5': 451, 'recall10': 473, 'recall20': 484}, Overall Score: 336.0000


100%|██████████| 496/496 [00:02<00:00, 177.77it/s]


b: 0.8, k1: 2.25, Recall Scores: {'recall1': 342, 'recall5': 455, 'recall10': 474, 'recall20': 484}, Overall Score: 342.0000


100%|██████████| 496/496 [00:02<00:00, 178.24it/s]


b: 0.8, k1: 2.5, Recall Scores: {'recall1': 344, 'recall5': 454, 'recall10': 474, 'recall20': 483}, Overall Score: 344.0000


100%|██████████| 496/496 [00:02<00:00, 176.42it/s]


b: 0.9, k1: 1.5, Recall Scores: {'recall1': 327, 'recall5': 450, 'recall10': 474, 'recall20': 489}, Overall Score: 327.0000


100%|██████████| 496/496 [00:02<00:00, 177.67it/s]


b: 0.9, k1: 1.75, Recall Scores: {'recall1': 339, 'recall5': 452, 'recall10': 474, 'recall20': 488}, Overall Score: 339.0000


100%|██████████| 496/496 [00:02<00:00, 179.02it/s]


b: 0.9, k1: 2.0, Recall Scores: {'recall1': 344, 'recall5': 453, 'recall10': 475, 'recall20': 485}, Overall Score: 344.0000


100%|██████████| 496/496 [00:02<00:00, 176.35it/s]


b: 0.9, k1: 2.25, Recall Scores: {'recall1': 350, 'recall5': 454, 'recall10': 472, 'recall20': 485}, Overall Score: 350.0000


100%|██████████| 496/496 [00:02<00:00, 175.51it/s]


b: 0.9, k1: 2.5, Recall Scores: {'recall1': 350, 'recall5': 455, 'recall10': 472, 'recall20': 484}, Overall Score: 350.0000


100%|██████████| 496/496 [00:02<00:00, 176.85it/s]


b: 1.0, k1: 1.5, Recall Scores: {'recall1': 332, 'recall5': 453, 'recall10': 475, 'recall20': 489}, Overall Score: 332.0000


100%|██████████| 496/496 [00:02<00:00, 176.84it/s]


b: 1.0, k1: 1.75, Recall Scores: {'recall1': 343, 'recall5': 455, 'recall10': 475, 'recall20': 488}, Overall Score: 343.0000


100%|██████████| 496/496 [00:02<00:00, 177.63it/s]


b: 1.0, k1: 2.0, Recall Scores: {'recall1': 348, 'recall5': 455, 'recall10': 473, 'recall20': 488}, Overall Score: 348.0000


100%|██████████| 496/496 [00:02<00:00, 173.51it/s]


b: 1.0, k1: 2.25, Recall Scores: {'recall1': 351, 'recall5': 457, 'recall10': 475, 'recall20': 487}, Overall Score: 351.0000


100%|██████████| 496/496 [00:02<00:00, 178.04it/s]

b: 1.0, k1: 2.5, Recall Scores: {'recall1': 357, 'recall5': 457, 'recall10': 474, 'recall20': 486}, Overall Score: 357.0000
Best Score: 357.0000
Best Parameters: b = 1.0, k1 = 2.5



