In [1]:
import math
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

## 评估函数 

In [2]:
# Load the ground-truth relevance judgments used by every metric below.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('../data_prep/robust04_gt.pkl', 'rb') as f:
    gt = pickle.load(f)

# Sanity check: size of the ground-truth entry stored for query '301'.
len(gt['301'])

448

In [3]:
def cal_F1(ranked_list: list, query: str, k: int, N_D: int) -> float:
    """
    Compute the F1 score of a ranked list truncated after its first k entries.

    Relevance is read from the module-level ``gt`` mapping: a document counts
    as a hit when it is contained in ``gt[query]``.

    Args:
        ranked_list: Retrieved document ids, best first.
        query: Key of the query in ``gt``.
        k: Cut-off position, counted from 1. A cut-off beyond the end of the
            list is clamped to the list length, so precision is computed over
            the documents that are actually kept.
        N_D: Number of relevant documents, used as the recall denominator.

    Returns:
        The harmonic mean of precision and recall at the cut-off, or 0.0 when
        nothing is kept (k <= 0 or an empty list), when N_D is 0, or when there
        are no hits.
    """
    # Slicing (unlike indexing) never raises when the list is shorter than k.
    kept = ranked_list[:max(k, 0)]
    if not kept:
        # Nothing retrieved: avoids the division by zero that k == 0 used to cause.
        return 0.0

    relevant = gt[query]  # hoisted: one lookup instead of one per position
    hits = sum(doc in relevant for doc in kept)
    precision = hits / len(kept)
    recall = hits / N_D if N_D != 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def cal_DCG(ranked_list: list, query: str, k: int, N_D: int, penalty=-0.23, normalized=False) -> float:
    """
    Compute a penalised DCG of a ranked list truncated after its first k entries.

    Each kept position i (0-based) contributes ``1 / log2(i + 2)`` when its
    document is contained in the module-level ``gt[query]`` and
    ``penalty / log2(i + 2)`` otherwise.

    Args:
        ranked_list: Retrieved document ids, best first.
        query: Key of the query in ``gt``.
        k: Cut-off position, counted from 1. A cut-off beyond the end of the
            list is clamped to the list length.
        N_D: Number of relevant documents; only used for the ideal ranking.
        penalty: Gain assigned to a non-relevant document.
        normalized: When True, divide by the ideal value for the same cut-off.

    Returns:
        The DCG value, or (when ``normalized``) the DCG divided by the value of
        an ideal ranking that places the relevant documents first and charges
        ``penalty`` for every remaining position. Returns 0.0 when nothing is
        kept or when the ideal value is exactly zero.

    NOTE(review): with a negative ``penalty`` and ``N_D`` smaller than the
    cut-off, the ideal value can be zero or negative, so the normalised score
    is not confined to [0, 1] and its sign can flip. Confirm this is the
    intended definition before comparing normalised scores across cut-offs.
    """
    depth = min(max(k, 0), len(ranked_list))  # never index past the end of the list
    if depth == 0:
        return 0.0

    relevant = gt[query]  # hoisted: one lookup instead of one per position
    # The same discounts serve the actual and the ideal ranking.
    discounts = [math.log(pos + 2, 2) for pos in range(depth)]

    value = 0
    for doc, discount in zip(ranked_list, discounts):
        value += (1 if doc in relevant else penalty) / discount
    if not normalized:
        return value

    # Ideal ranking: the first N_D positions are relevant, the rest are penalised.
    ideal = sum([(1 if pos < N_D else penalty) / discount for pos, discount in enumerate(discounts)])
    if ideal == 0:
        return 0.0
    return value / ideal

## fixed-k的整套流程

In [4]:
def ori2rt(dataset_name: str, original_data: dict, top_n: int = 300) -> dict:
    """
    Reduce raw retrieval results to plain ranked lists of document ids.

    Args:
        dataset_name: ``'BM25'`` when each query maps to a record whose
            ``'retrieved_documents'`` entry holds the results; any other value
            means each query maps directly to the results.
        original_data: Mapping from query key to that query's raw results.
            Each result must expose a ``'doc_id'`` entry.
            Assumes the results are a sliceable sequence such as a list;
            TODO confirm against the pickled data.
        top_n: Maximum number of results kept per query (default 300).

    Returns:
        Mapping from query key to the list of its first ``top_n`` doc ids. A
        query with fewer results yields a shorter list instead of raising
        IndexError.
    """
    is_bm25 = dataset_name == 'BM25'  # decided once, not once per query
    rt_data = {}
    for key, entry in original_data.items():
        docs = entry['retrieved_documents'] if is_bm25 else entry
        rt_data[key] = [doc['doc_id'] for doc in docs[:top_n]]
    return rt_data


def dataset_prepare(dataset_name: str) -> list:
    """
    Load the five pickled test splits of a run and convert each with ori2rt.

    Args:
        dataset_name: ``'BM25'`` selects the BM25 result files; any other
            value selects the DRMM result files.

    Returns:
        A list with one converted split per index 1..5, in that order.
    """
    # Only the test splits are needed here; the training splits are not loaded.
    if dataset_name == 'BM25':
        path_template = '../data_prep/BM25_results/split_{}/BM25_test_s{}.pkl'
        source = 'BM25'
    else:
        path_template = '../data_prep/drmm_results/split_{}/drmm_test_s{}.pkl'
        source = 'DRMM'

    test_dataset_list = []
    for split in range(1, 6):
        with open(path_template.format(split, split), 'rb') as handle:
            raw_split = pickle.load(handle)
        test_dataset_list.append(ori2rt(source, raw_split))
    return test_dataset_list


def test_scores(dataset: dict, gt: dict, k: list) -> float:
    F1_test, DCG_test, NDCG_test = [], [], []
    for key in dataset:
        N_D = sum(dataset[key][i] in gt[key] for i in range(300))
        F1_test.append(cal_F1(dataset[key], key, min(k[0] + 1, 300), N_D))
        DCG_test.append(cal_DCG(dataset[key], key, min(k[1] + 1, 300), N_D))
        NDCG_test.append(cal_DCG(dataset[key], key, min(k[2] + 1, 300), N_D, penalty=-0.78, normalized=True))
    F1, DCG, NDCG = np.mean(F1_test), np.mean(DCG_test), np.mean(NDCG_test)
    return F1, DCG, NDCG


def k_fold(dataset_name: str, fixed_k=50) -> tuple:
    """
    Evaluate a fixed cut-off on every test split and average the metrics.

    Args:
        dataset_name: Passed to dataset_prepare to choose the result files.
        fixed_k: Number of documents kept per query, counted from 1. It is
            converted to the 0-based index that test_scores expects.

    Returns:
        A tuple ``(F1, DCG, NDCG)``: each metric averaged over the splits.
        Relevance judgments are taken from the module-level ``gt``.
    """
    # One (F1, DCG, NDCG) tuple per test split, all using the same cut-off.
    per_split = [
        test_scores(split, gt, [fixed_k - 1] * 3)
        for split in dataset_prepare(dataset_name)
    ]
    return tuple(np.mean([scores[metric] for scores in per_split]) for metric in range(3))

## BM25和DRMM的fixed-k结果

In [5]:
# Fixed-k results for BM25: one line per cut-off, each showing what k_fold returns.
for template, cutoff in (('BM25 k = 5: {}', 5), ('BM25 k = 10: {}', 10), ('BM25 k = 50: {}', 50)):
    print(template.format(k_fold('BM25', cutoff)))

BM25 k = 5: (0.17474618534792546, 1.0747526914281167, 0.06382077566452259)
BM25 k = 10: (0.24860544368215973, 1.489767906670021, -0.30712395236274276)
BM25 k = 50: (0.2922332254375074, 1.7999268124946117, 2.9727567739747025)


In [6]:
# Fixed-k results for DRMM: one line per cut-off, each showing what k_fold returns.
for template, cutoff in (('DRMM k = 5: {}', 5), ('DRMM k = 10: {}', 10), ('DRMM k = 50: {}', 50)):
    print(template.format(k_fold('DRMM', cutoff)))

DRMM k = 5: (0.19574095035753317, 1.1092058369577242, 0.09973587877810637)
DRMM k = 10: (0.26104689301570955, 1.5333460982847114, -0.02948479582395295)
DRMM k = 50: (0.29725099683536305, 1.778329941465822, 3.044619810537594)
