In [1]:
import math
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt

## 数据集路径

In [2]:
# Root directory holding the pickled ground truth and the per-run test rankings.
# Override it with the RLT_DATASET_BASE environment variable to run the notebook on
# a machine other than the original lab server; the default path is unchanged.
DATASET_BASE = os.environ.get(
    'RLT_DATASET_BASE',
    '/home/LAB/wangd/graduation_project/ranked list truncation/dataset',
)

## 评估函数 

In [3]:
# Load the Robust04 ground truth. Each query's relevant-document collection is turned
# into a set so that dataset_prepare can test membership cheaply.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
gt_file = '{}/robust04_gt.pkl'.format(DATASET_BASE)
with open(gt_file, 'rb') as f:
    gt = pickle.load(f)
for qid, relevant_docs in gt.items():
    gt[qid] = set(relevant_docs)
# Sanity check: number of relevant documents recorded for query '301'.
len(gt['301'])

448

In [4]:
def cal_F1(ranked_list: list, k: int) -> float:
    """
    Compute the F1 score of a ranked list truncated after its first k items.

    ranked_list: relevance labels in rank order (0/1 labels, as built by
        dataset_prepare); relevant items are counted by summing the labels.
    k: truncation cutoff, counted from 1 (the top k items are kept).
    Returns 0 when precision and recall are both zero. Recall is taken as 0
    when the list contains no relevant items at all.
    """
    hits = sum(ranked_list[:k])
    total_relevant = sum(ranked_list)
    precision = hits / k
    if total_relevant == 0:
        recall = 0
    else:
        recall = hits / total_relevant
    denominator = precision + recall
    if denominator == 0:
        return 0
    return 2 * precision * recall / denominator


def cal_DCG(ranked_list: list, k: int, penalty=-1) -> float:
    """
    Compute a penalized DCG over the first k positions of a ranked list.

    The item at 0-based position i contributes 1 / log2(i + 2) if its label is
    truthy (relevant), otherwise penalty / log2(i + 2). With the default
    penalty of -1, non-relevant items lower the score, so the result can be
    negative.

    ranked_list: relevance labels in rank order (truthy = relevant).
    k: truncation cutoff, counted from 1. If k exceeds len(ranked_list), only
       the available items are scored (matching cal_F1, which slices) instead
       of raising IndexError. k <= 0 yields 0.
    penalty: gain given to each non-relevant item before discounting.
    """
    value = 0.0
    # max(k, 0) keeps negative k scoring nothing, as range(k) did before,
    # rather than letting a negative slice drop items from the end.
    for i, label in enumerate(ranked_list[:max(k, 0)]):
        gain = 1 if label else penalty
        # math.log2 is more accurate than math.log(x, 2).
        value += gain / math.log2(i + 2)
    return value

## Fixed-k 的完整流程

In [5]:
def dataset_prepare(dataset_name: str) -> dict:
    """
    Load one run's test-split rankings and convert each into a 0/1 relevance list.

    dataset_name: run prefix; the file read is '{DATASET_BASE}/{dataset_name}_test.pkl'.
    Returns a dict mapping each query id in the test file to a list whose i-th
    entry is 1 if the i-th ranked document is in gt[query id], else 0.

    Relies on the notebook globals DATASET_BASE and gt (query id -> set of
    relevant doc ids). Raises KeyError if a query id in the test file has no
    entry in gt.
    """
    # Only the test split is needed here; no training data.
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open('{}/{}_test.pkl'.format(DATASET_BASE, dataset_name), 'rb') as f:
        test_data = pickle.load(f)
    rl_data = {}
    for qid, ranking in test_data.items():
        # Look up this query's relevant set once, not once per ranked document.
        relevant = gt[qid]
        # The key order of `ranking` is used as the rank order.
        # Presumably `ranking` maps doc id -> retrieval score; TODO confirm.
        rl_data[qid] = [1 if doc_id in relevant else 0 for doc_id in ranking.keys()]
    return rl_data


def test_scores(dataset_name: str, k: list) -> tuple:
    """
    Evaluate a fixed-k truncation baseline on one run's test split.

    dataset_name: run name passed to dataset_prepare (e.g. 'bm25').
    k: cutoffs as [k_F1, k_DCG]; k[0] is used for F1 and k[1] for DCG.
       A single integer is also accepted and used for both metrics.
    Returns (mean F1, mean DCG) over all queries in the test split. np.mean
    yields nan, with a RuntimeWarning, if the split contains no queries.
    """
    if isinstance(k, (int, np.integer)):
        k = [k, k]
    k_f1, k_dcg = k[0], k[1]
    dataset = dataset_prepare(dataset_name)
    F1_test, DCG_test = [], []
    for ranked_list in dataset.values():
        F1_test.append(cal_F1(ranked_list, k_f1))
        DCG_test.append(cal_DCG(ranked_list, k_dcg))
    F1, DCG = np.mean(F1_test), np.mean(DCG_test)
    return F1, DCG

## Fixed-k结果

In [7]:
# Fixed-k baseline for BM25: truncate every ranking at the same cutoff and report
# (mean F1, mean DCG), using that cutoff for both metrics.
for cutoff in (5, 10, 50):
    print('BM25 k = {}: {}'.format(cutoff, test_scores('bm25', [cutoff] * 2)))

BM25 k = 5: (0.21689113823696973, 0.12490687197930528)
BM25 k = 10: (0.2767588393659435, -0.35275920060181265)
BM25 k = 50: (0.2916372147309367, -5.546589431773338)


In [8]:
# Fixed-k baseline for DRMM: truncate every ranking at the same cutoff and report
# (mean F1, mean DCG), using that cutoff for both metrics.
for cutoff in (5, 10, 50):
    print('DRMM k = {}: {}'.format(cutoff, test_scores('drmm', [cutoff] * 2)))

DRMM k = 5: (0.19738517287426238, 0.24618597771981257)
DRMM k = 10: (0.2619403550629597, -0.20228091950106403)
DRMM k = 50: (0.29643726947393373, -5.085496368879661)


In [9]:
# Fixed-k baseline for DRMM-TKS: truncate every ranking at the same cutoff and report
# (mean F1, mean DCG), using that cutoff for both metrics.
# NOTE(review): these scores are far higher than the BM25 and DRMM results, and
# DCG stays positive at k = 50 despite the -1 penalty. Confirm that
# drmm_tks_test.pkl is built the same way as the other runs (e.g. no leakage).
for cutoff in (5, 10, 50):
    print('DRMM-TKS k = {}: {}'.format(cutoff, test_scores('drmm_tks', [cutoff] * 2)))

DRMM-TKS k = 5: (0.20102252363081768, 2.8051911421233)
DRMM-TKS k = 10: (0.34098495770312615, 4.234743781053166)
DRMM-TKS k = 50: (0.6276142077551684, 7.6007801135598845)
