In [22]:
import math
import os
import pickle

import numpy as np

## 工作路径

In [23]:
# Root directory holding the pickled datasets. Set the DATASET_BASE environment
# variable to point elsewhere instead of editing this cell; the original lab
# path is kept as the fallback.
DATASET_BASE = os.environ.get(
    'DATASET_BASE',
    '/home/LAB/wangd/graduation_project/ranked list truncation/dataset',
)

## 评估函数

In [24]:
# Load the ground truth (query id -> relevant document ids), then turn each
# value into a set so the relevance lookups in dataset_prepare are fast.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('{}/robust04_gt.pkl'.format(DATASET_BASE), 'rb') as gt_file:
    gt = pickle.load(gt_file)
for qid, doc_ids in gt.items():
    gt[qid] = set(doc_ids)
# Sanity check: number of relevant documents for query '301'.
len(gt['301'])

448

In [25]:
def cal_F1(ranked_list: list, k: int) -> float:
    """
    Compute the F1 score of the ranked list truncated after its first k items.

    ranked_list: relevance labels in rank order (1 = relevant, 0 = not
        relevant, as produced by dataset_prepare).
    k: truncation cut-off, counted from 1. A k larger than the list length is
        clamped to the list length, since truncating past the end keeps the
        whole list. k <= 0 means nothing is returned and yields 0.
    Returns 0 when precision and recall are both 0, including when the list
    contains no relevant items.
    """
    # Precision must be measured over the items actually kept, not over k.
    k = min(k, len(ranked_list))
    if k <= 0:
        return 0.0
    count, N_D = sum(ranked_list[:k]), sum(ranked_list)
    p_k = count / k
    r_k = (count / N_D) if N_D != 0 else 0
    return (2 * p_k * r_k / (p_k + r_k)) if p_k + r_k != 0 else 0.0


def cal_DCG(ranked_list: list, k: int, penalty=-1) -> float:
    """
    Compute the DCG of the ranked list truncated after its first k items.

    Relevant (truthy) items gain 1 and non-relevant items gain ``penalty``.
    The gain at 0-based position i is divided by log2(i + 2).

    ranked_list: relevance labels in rank order.
    k: truncation cut-off, counted from 1. A k larger than the list length is
        clamped to the list length instead of raising IndexError; k <= 0
        yields 0.
    penalty: gain assigned to non-relevant items. The default -1 makes keeping
        an irrelevant item lower the score.
    """
    value = 0.0
    # max(k, 0) keeps a negative k from slicing from the end of the list.
    for i, rel in enumerate(ranked_list[:max(k, 0)]):
        gain = 1 if rel else penalty
        # math.log2 is more accurate than math.log(x, 2).
        value += gain / math.log2(i + 2)
    return value

## Oracle的整体流程

In [26]:
def dataset_prepare(dataset_name: str, ground_truth: dict = None) -> dict:
    """
    Load a model's test-set rankings and convert them into binary relevance lists.

    Only the test split is loaded; the oracle needs no training data.

    dataset_name: prefix of the pickle file '<DATASET_BASE>/<dataset_name>_test.pkl'.
    ground_truth: mapping query id -> collection of relevant document ids.
        Defaults to the module-level ``gt`` loaded earlier in the notebook.
    Returns a dict mapping each query id to a list of 1/0 labels, one per
    document, in the key order of that query's entry in the pickle
    (presumably rank order -- TODO confirm against how the pickle was built).
    Raises KeyError if a test query has no ground-truth entry (assuming
    ground_truth is a plain dict).
    """
    if ground_truth is None:
        ground_truth = gt
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('{}/{}_test.pkl'.format(DATASET_BASE, dataset_name), 'rb') as f:
        test_data = pickle.load(f)
    return {
        qid: [1 if doc_id in ground_truth[qid] else 0 for doc_id in test_data[qid].keys()]
        for qid in test_data
    }

def test_scores(dataset_name: str, max_k: int = 300) -> tuple:
    """
    Compute the oracle truncation scores for one ranking model.

    For each query the oracle picks, separately for each metric, the cut-off
    k in [0, max_k] that maximises F1 and DCG. Truncating at k = 0 (returning
    nothing) scores 0 for both. The per-query best values are then averaged
    over all queries.

    dataset_name: dataset prefix passed to dataset_prepare.
    max_k: largest cut-off considered (default 300).
    Returns (mean best F1, mean best DCG) as plain Python floats.
    Raises ValueError if the dataset contains no queries.
    """
    dataset = dataset_prepare(dataset_name)
    if not dataset:
        raise ValueError('dataset {!r} contains no queries'.format(dataset_name))
    F1_best, DCG_best = [], []
    cutoffs = range(1, max_k + 1)
    for ranked_list in dataset.values():
        # The leading 0 is the score of truncating at k = 0.
        F1_best.append(max([0] + [cal_F1(ranked_list, k) for k in cutoffs]))
        DCG_best.append(max([0] + [cal_DCG(ranked_list, k) for k in cutoffs]))
    return float(np.mean(F1_best)), float(np.mean(DCG_best))

## Oracle结果

In [27]:
# Evaluate the oracle for each ranking model and print
# (mean best F1, mean best DCG).
oracle_runs = [
    ('BM25 Oracle: {}', 'bm25'),
    ('DRMM Oracle: {}', 'drmm'),
    ('DRMM-TKS Oracle: {}', 'drmm_tks'),
]
for message_format, model_name in oracle_runs:
    print(message_format.format(test_scores(model_name)))

BM25 Oracle: (0.44132237243494216, 1.481949106601391)
DRMM Oracle: (0.4398188071248245, 1.751231140274798)
DRMM-TKS Oracle: (0.854701100644778, 13.583689296206474)
