In [1]:
import math
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

## 评估函数

In [2]:
# 导入ground truth
with open('../data_prep/robust04_data/robust04_gt.pkl', 'rb') as f: gt = pickle.load(f)
len(gt['301'])

448

In [3]:
def cal_F1(ranked_list: list, query: str, k: int, N_D: int) -> float:
    """
    计算F1 score
    k: 截断到第k个，从1计数
    """
    count = sum([ranked_list[i] in gt[query] for i in range(k)])
    p_k = (count / k) if k != 0 else 0
    r_k = (count / N_D) if N_D != 0 else 0
    return (2 * p_k * r_k / (p_k + r_k)) if p_k + r_k != 0 else 0


def cal_DCG(ranked_list: list, query: str, k: int, N_D: int, penalty=-0.25, normalized=False) -> float:
    """
    计算DCG
    """
    value = 0
    for i in range(k): value += (1 / math.log(i + 2, 2)) if ranked_list[i] in gt[query] else (penalty / math.log(i + 2, 2))
    ideal_DCG = sum([1 / math.log(i + 2, 2) for i in range(k)]) if N_D >= k else sum([1 / math.log(i + 2, 2) for i in range(N_D)] + [penalty / math.log(i + 2, 2) for i in range(N_D, k)])
    return value if not normalized else value / ideal_DCG

## Oracle的整体流程(自制数据集)

In [9]:
def ori2rt(dataset_name: str, original_data: dict) -> dict:
    rt_data = {}
    for key in original_data: 
        rt_data[key] = [original_data[key]['retrieved_documents'][i]['doc_id'] for i in range(300)] if dataset_name == 'BM25' else [original_data[key][i]['doc_id'] for i in range(300)]
    return rt_data


def dataset_prepare(dataset_name: str) -> list:
    test_dataset_list = []
    if dataset_name == 'BM25':
        for i in range(1, 6):
            with open('../data_prep/my_results/BM25_results/split_{}/BM25_test_s{}.pkl'.format(i, i), 'rb') as f: test_dataset_list.append(pickle.load(f))
            test_dataset_list[-1] = ori2rt('BM25', test_dataset_list[-1])
    else:
        for i in range(1, 6):
            with open('../data_prep/my_results/drmm_results/split_{}/drmm_test_s{}.pkl'.format(i, i), 'rb') as f: test_dataset_list.append(pickle.load(f))
            test_dataset_list[-1] = ori2rt('DRMM', test_dataset_list[-1])
    return test_dataset_list


def test_scores(dataset: dict, gt: dict) -> tuple:
    F1_test, DCG_test, NDCG_test = [], [], []
    for key in dataset:
        per_k_F1, per_k_DCG, N_D = [], [], sum(dataset[key][i] in gt[key] for i in range(300))
        for i in range(300): 
            per_k_F1.append(cal_F1(dataset[key], key, i, N_D))
            per_k_DCG.append(cal_DCG(dataset[key], key, i, N_D))
        F1_test.append(max(per_k_F1))
        DCG_test.append(max(per_k_DCG))
    F1_result, DCG_result = np.mean(F1_test), np.mean(DCG_test)
    return F1_result, DCG_result


def k_fold(dataset_name: str) -> float:
    test_dataset_list = dataset_prepare(dataset_name)
    # 在测试集中得到先验指导的结果列表
    F1_list, DCG_list = [], []
    for dataset in test_dataset_list:
        F1_result, DCG_result = test_scores(dataset, gt)
        F1_list.append(F1_result)
        DCG_list.append(DCG_result)
    return np.mean(F1_list), np.mean(DCG_list)

## AAAI-2021 ranked_list and processing

In [11]:
def ori2rl(original_data: dict) -> dict:
    rl_data = {}
    for key in original_data: rl_data[key] = list(original_data[key].keys())
    return rl_data


def dataset_prepare(dataset_name: str) -> dict:
    if dataset_name == 'BM25':
        with open('../data_prep/wc_results/bm25_test.pkl', 'rb') as f: test_dataset = pickle.load(f)
        test_ranked_list = ori2rl(test_dataset)
    else:
        with open('../data_prep/wc_results/drmm_test.pkl', 'rb') as f: test_dataset = pickle.load(f)
        test_ranked_list = ori2rl(test_dataset)
    return test_ranked_list


def test_scores(dataset_name: str, gt: dict) -> tuple:
    dataset = dataset_prepare(dataset_name)
    F1_test, DCG_test = [], []
    for key in dataset:
        per_k_F1, per_k_DCG, N_D = [], [], sum(dataset[key][i] in gt[key] for i in range(300))
        for i in range(300): 
            per_k_F1.append(cal_F1(dataset[key], key, i, N_D))
            per_k_DCG.append(cal_DCG(dataset[key], key, i, N_D))
        F1_test.append(max(per_k_F1))
        DCG_test.append(max(per_k_DCG))
    F1_result, DCG_result = np.mean(F1_test), np.mean(DCG_test)
    return F1_result, DCG_result

## 经过召回优化的drmm-tks处理流程

In [7]:
def ori2rl(original_data: dict) -> dict:
    rl_data = {}
    for key in original_data: 
        rl_data[key] = [original_data[key][i]['doc_id'] for i in range(300)]
    return rl_data


def dataset_prepare() -> list:
    test_dataset_list = []
    for i in range(1, 6):
        with open('../data_prep/my_results/drmm_tks_results/split_{}/drmm_tks_test_s{}.pkl'.format(i, i), 'rb') as f: test_dataset_list.append(pickle.load(f))
        test_dataset_list[-1] = ori2rl(test_dataset_list[-1])
    return test_dataset_list


def test_scores(dataset: dict, gt: dict) -> tuple:
    F1_test, DCG_test = [], []
    for key in dataset:
        per_k_F1, per_k_DCG, N_D = [], [], sum(dataset[key][i] in gt[key] for i in range(300))
        for i in range(300): 
            per_k_F1.append(cal_F1(dataset[key], key, i, N_D))
            per_k_DCG.append(cal_DCG(dataset[key], key, i, N_D))
        F1_test.append(max(per_k_F1))
        DCG_test.append(max(per_k_DCG))
    F1_result, DCG_result = np.mean(F1_test), np.mean(DCG_test)
    return F1_result, DCG_result


def k_fold() -> float:
    test_dataset_list = dataset_prepare()
    # 在测试集中得到先验指导的结果列表
    F1_list, DCG_list = [], []
    for dataset in test_dataset_list:
        F1_result, DCG_result = test_scores(dataset, gt)
        F1_list.append(F1_result)
        DCG_list.append(DCG_result)
    return np.mean(F1_list), np.mean(DCG_list)

## BM25和DRMM的Oracle结果(吴晨)

In [12]:
print('BM25 Oracle: {}'.format(test_scores('BM25', gt)))
print('DRMM Oracle: {}'.format(test_scores('DRMM', gt)))

BM25 Oracle: (0.44132237243494216, 2.7876816987782345)
DRMM Oracle: (0.4398188071248245, 3.1463638415401878)


## BM25和DRMM的Oracle结果(自制)

In [10]:
print('BM25 Oracle: {}'.format(k_fold('BM25')))
print('DRMM Oracle: {}'.format(k_fold('DRMM')))

BM25 Oracle: (0.4294348991288629, 3.2136030192457157)
DRMM Oracle: (0.45171301531736374, 3.1570649715650303)


## DRMM-TKS Oracle结果

In [8]:
print('DRMM_TKS Oracle: {}'.format(k_fold()))

DRMM_TKS Oracle: (0.8392034355460855, 12.946841402517625)
