In [1]:
import math
import pickle
import numpy as np
import matplotlib.pyplot as plt

## 工作路径

In [2]:
ROBUST_BASE = '/home/LAB/wangd/graduation_project/ranked list truncation/dataset/robust04'
MQ_BASE = '/home/LAB/wangd/graduation_project/ranked list truncation/dataset/mq2007'

## 评估函数

In [3]:
def cal_F1(ranked_list: list, k: int) -> float:
    """
    计算F1 score
    k: 截断到第k个，从1计数
    """
    count, N_D = sum(ranked_list[:k]), sum(ranked_list)
    p_k = count / k
    r_k = (count / N_D) if N_D != 0 else 0
    return (2 * p_k * r_k / (p_k + r_k)) if p_k + r_k != 0 else 0


def cal_DCG(ranked_list: list, k: int, penalty=-1) -> float:
    """
    计算DCG
    """
    value = 0
    for i in range(k): 
        value += (1 / math.log(i + 2, 2)) if ranked_list[i] else (penalty / math.log(i + 2, 2))
    return value

## fixed-k的整套流程

In [4]:
def dataset_prepare(dataset_name: str, DATASET_BASE) -> list:
    # 这里不需要train
    with open('{}/{}_test.pkl'.format(DATASET_BASE, dataset_name), 'rb') as f:
        test_data = pickle.load(f)
    rl_data = {}
    for key in test_data: 
        rl_data[key] = list(map(lambda x: 1 if x in gt[key] else 0, test_data[key].keys()))
    return rl_data


def test_scores(dataset_name: str, k: list, DATASET_BASE) -> float:
    dataset = dataset_prepare(dataset_name, DATASET_BASE)
    F1_test, DCG_test = [], []
    for key in dataset:
        F1_test.append(cal_F1(dataset[key], k[0]))
        DCG_test.append(cal_DCG(dataset[key], k[1]))
    F1, DCG = np.mean(F1_test), np.mean(DCG_test)
    return F1, DCG

# Robust04

In [5]:
# 导入ground truth
with open('{}/gt.pkl'.format(ROBUST_BASE), 'rb') as f:
    gt = pickle.load(f)
    for key in gt: gt[key] = set(gt[key])
len(gt['301'])

448

## Fixed-k结果

In [6]:
print('BM25 k = 5: {}'.format(test_scores('bm25', [5]*2, ROBUST_BASE)))
print('BM25 k = 10: {}'.format(test_scores('bm25', [10]*2, ROBUST_BASE)))
print('BM25 k = 50: {}'.format(test_scores('bm25', [30]*2, ROBUST_BASE)))

BM25 k = 5: (0.21689113823696973, 0.12490687197930528)
BM25 k = 10: (0.2767588393659435, -0.35275920060181265)
BM25 k = 50: (0.3049257553537696, -2.9305045984628264)


In [7]:
print('DRMM k = 5: {}'.format(test_scores('drmm', [5]*2, ROBUST_BASE)))
print('DRMM k = 10: {}'.format(test_scores('drmm', [10]*2, ROBUST_BASE)))
print('DRMM k = 50: {}'.format(test_scores('drmm', [30]*2, ROBUST_BASE)))

DRMM k = 5: (0.19738517287426238, 0.24618597771981257)
DRMM k = 10: (0.2619403550629597, -0.20228091950106403)
DRMM k = 50: (0.30633865387481135, -2.5531886302444327)


In [8]:
print('DRMM-TKS k = 5: {}'.format(test_scores('drmm_tks', [5]*2, ROBUST_BASE)))
print('DRMM-TKS k = 10: {}'.format(test_scores('drmm_tks', [10]*2, ROBUST_BASE)))
print('DRMM-TKS k = 50: {}'.format(test_scores('drmm_tks', [30]*2, ROBUST_BASE)))

DRMM-TKS k = 5: (0.20102252363081768, 2.8051911421233)
DRMM-TKS k = 10: (0.34098495770312615, 4.234743781053166)
DRMM-TKS k = 50: (0.584051644295087, 7.095262902068821)


# MQ2007

In [9]:
# 导入ground truth
with open('{}/gt.pkl'.format(MQ_BASE), 'rb') as f:
    gt = pickle.load(f)
    for key in gt: gt[key] = set(gt[key])
len(gt['10'])

16

## Fixed-k结果

In [10]:
print('BM25 k = 5: {}'.format(test_scores('bm25', [5]*2, MQ_BASE)))
print('BM25 k = 10: {}'.format(test_scores('bm25', [10]*2, MQ_BASE)))
print('BM25 k = 30: {}'.format(test_scores('bm25', [30]*2, MQ_BASE)))

BM25 k = 5: (0.23304048597449337, -0.5218541842104699)
BM25 k = 10: (0.3230539250321036, -0.8666806662382415)
BM25 k = 30: (0.4397382645871278, -2.377816019581174)


In [11]:
print('DRMM-TKS k = 5: {}'.format(test_scores('drmm_tks', [5]*2, MQ_BASE)))
print('DRMM-TKS k = 10: {}'.format(test_scores('drmm_tks', [10]*2, MQ_BASE)))
print('DRMM-TKS k = 30: {}'.format(test_scores('drmm_tks', [30]*2, MQ_BASE)))

DRMM-TKS k = 5: (0.4102969651809003, 0.9609889342923913)
DRMM-TKS k = 10: (0.5633186939814676, 1.4757984674778055)
DRMM-TKS k = 30: (0.5213129611137165, -0.19344082448668362)
