In [1]:
import pandas as pd
import math
import numpy as np

In [2]:
def generate_tweetid_gain(file_path):
    data_dict = {}
    with open(file_path, 'r', errors='ignore') as f:
        for line in f:
            result = line.strip().split(' ')    # 171 Q0 305345146675949568 0
            query_id = result[0]                # 查询id
            docu_id = result[2]                 # 文档id
            rel = int(result[3])                # 文档相关性
            if query_id not in data_dict:       # 为每个query建立列表，记录query查询回来的排序后的文档以及相关性
                data_dict[query_id] = []
            if rel > 0:                         # 记录文档id和相关性
                data_dict[query_id].append([docu_id,rel])
            else:
                data_dict[query_id].append([docu_id,0])
    return data_dict

In [3]:
file_path = "./qrels.txt"
data_dict = generate_tweetid_gain(file_path)
data_dict

{'171': [['305345146675949568', 0],
  ['307360182604820481', 2],
  ['307575501373992961', 2],
  ['307585630601371648', 2],
  ['307592941277433856', 2],
  ['307462808801513472', 2],
  ['307581469876944897', 2],
  ['307380759835062272', 2],
  ['307431003394281472', 2],
  ['307496547812270080', 2],
  ['307364469183483906', 2],
  ['307381804225163264', 1],
  ['307510921696264193', 2],
  ['300274514149920768', 2],
  ['307587392213221376', 2],
  ['307577934087069697', 2],
  ['307519343858688000', 2],
  ['307261754851864576', 2],
  ['307665343382441985', 2],
  ['307597580177653760', 2],
  ['307410191257841664', 2],
  ['307655801345024001', 2],
  ['299228458914037760', 2],
  ['307615959626158080', 2],
  ['307557520413687809', 2],
  ['307397692240297984', 1],
  ['307551811961778176', 2],
  ['307488347935358976', 2],
  ['307568475902193664', 2],
  ['307549156946358273', 2],
  ['307540059509370880', 2],
  ['307494886838530048', 2],
  ['307695965991735296', 2],
  ['307491908907855873', 2],
  ['307

In [4]:
len(data_dict)

55

In [5]:
# 平均精度均值
def MAP_eval(data_dict):
    MAP = 0
    for query_result in data_dict:
        data = data_dict[query_result]              # 一条查询对应的所有文档及其相关性 query_id[[doc_id,rel]，[doc_id,rel],.....]
        AP = 0                                      # 初始化每一条query的AP为0
        Rank = []                                   # Rank记录每一条query中相关文档出现的位置，方便用来计算AP
        for result in data:                         # [doc_id,rel]
            rel = result[1]
            if(rel > 0):
                index = data.index(result) + 1      # 记录该doc是第几个出现的
                Rank.append(index)
        # print(Rank)
        num_related_doc = len(Rank)                 # 总的相关文档数
        i = 1
        for index in Rank:
            precision = i / index                   # 计算精度
            # print("precision= ",precision)
            AP += precision
            i += 1
        AP /= num_related_doc                       # 计算AP
        # print(AP)
        MAP += AP                   
    MAP = MAP / len(data_dict)                      # 计算所有query的AP,得到MAP
    return MAP

In [6]:
MAP = MAP_eval(data_dict)
print(MAP)

0.8772843634992499


In [7]:
# 平均倒数排名,对比MAP，当返回的相关结果较少时，使用它更加合适
def MRR_eval(data_dict):
    MRR = 0                                                       
    for query_id in data_dict:                      # query_id
        data = data_dict[query_id]                  # 一条查询对应的所有文档及其相关性 query_id[[doc_id,rel]，[doc_id,rel],.....]
        RR = 0                                      # 初始化每一条query的RR为0                     
        for result in data:                         # [doc_id,rel]
            rel = result[1]
            if(rel > 0):
                index = data.index(result) + 1      # 记录该doc是第几个出现的相关文档
                RR = 1 / index                      # 计算RR
                MRR += RR
                break
                  
    MRR = MRR / len(data_dict)                      # 计算所有query的RR,得到MRR
    return MRR

In [8]:
MRR_eval(data_dict)

0.79737012987013

In [9]:
import math


In [10]:
data = [['305345146675949568', 0], ['307381804225163264', 1], ['307397692240297984', 1], ['303752913380048896', 1], ['307498389086535681', 1], ['307332118504144896', 1], ['307469041541468160', 1], ['307009861714067456', 1], ['307539908522827777', 1], ['307466998923812864', 1], ['307360182604820481', 2], ['307575501373992961', 2], ['307585630601371648', 2], ['307592941277433856', 2], ['307462808801513472', 2], ['307581469876944897', 2], ['307380759835062272', 2], ['307431003394281472', 2], ['307496547812270080', 2], ['307364469183483906', 2], ['307510921696264193', 2], ['300274514149920768', 2], ['307587392213221376', 2], ['307577934087069697', 2], ['307519343858688000', 2], ['307261754851864576', 2], ['307665343382441985', 2], ['307597580177653760', 2], ['307410191257841664', 2], ['307655801345024001', 2], ['299228458914037760', 2], ['307615959626158080', 2], ['307557520413687809', 2], ['307551811961778176', 2], ['307488347935358976', 2], ['307568475902193664', 2], ['307549156946358273', 2], ['307540059509370880', 2], ['307494886838530048', 2], ['307695965991735296', 2], ['307491908907855873', 2], ['307611874365669377', 2], ['307613187208007680', 2], ['307613061387280385', 2], ['307610469299023873', 2], ['307617154998607873', 2], ['307615707976327168', 2], ['307603519337271297', 2], ['307610414743707648', 2], ['307606899916939264', 2], ['307595491422662656', 2], ['307630518067929088', 2], ['307349038309732352', 2], ['307407506920067073', 2], ['307592068874772480', 2], ['307690291090046976', 2], ['307622246900441089', 2], ['307533386363330560', 2], ['307446019006803969', 2], ['307278641111384064', 2], ['307335901766381568', 2], ['307497910965243904', 2], ['307490721923989505', 2], ['307424758105001984', 2], ['307628148269404161', 2], ['307513329222553600', 2], ['307610326663327745', 2], ['307600201642811392', 2], ['307634058056331265', 2], ['307634095784095746', 2], ['307634242605686785', 2], ['307388573810819072', 2], ['307402142384267264', 2], ['307416801489342464', 2], ['307497638310326272', 2], ['307656073949618176', 2], ['307557365186699264', 2], ['307562306101981184', 2], ['307513899614363649', 2], ['307409448899588096', 2]]
data.sort(key = lambda x:x[1])
print(data)

[['305345146675949568', 0], ['307381804225163264', 1], ['307397692240297984', 1], ['303752913380048896', 1], ['307498389086535681', 1], ['307332118504144896', 1], ['307469041541468160', 1], ['307009861714067456', 1], ['307539908522827777', 1], ['307466998923812864', 1], ['307360182604820481', 2], ['307575501373992961', 2], ['307585630601371648', 2], ['307592941277433856', 2], ['307462808801513472', 2], ['307581469876944897', 2], ['307380759835062272', 2], ['307431003394281472', 2], ['307496547812270080', 2], ['307364469183483906', 2], ['307510921696264193', 2], ['300274514149920768', 2], ['307587392213221376', 2], ['307577934087069697', 2], ['307519343858688000', 2], ['307261754851864576', 2], ['307665343382441985', 2], ['307597580177653760', 2], ['307410191257841664', 2], ['307655801345024001', 2], ['299228458914037760', 2], ['307615959626158080', 2], ['307557520413687809', 2], ['307551811961778176', 2], ['307488347935358976', 2], ['307568475902193664', 2], ['307549156946358273', 2], 

In [11]:
# 归一化折损累积增益,
def NDCG_eval(data_dict):
    NDCG = 0
    for query_result in data_dict:
        data = data_dict[query_result]                      # 一条查询对应的所有文档及其相关性 query_id[[doc_id,rel]，[doc_id,rel],.....]
        
        sort_data = data                                    # 计算IDCG准备的排序数组
        sort_data.sort(key = lambda x:x[1])                 # 按照文档评分rel排序
        # print(sort_data)
        CG = 0                                              # 初始化每一条query的CG为0
        DCG = 0                                             # 初始化每一条query的CG为0
        IDCG = 0                                            # 初始化每一条query的IDCG为0
        # 计算DCG
        for result in data:                                 # [doc_id,rel]
            rel = result[1]
            CG +=  rel                                      # CGn = sum(rel(i))
            i = data.index(result) + 1                      # 记录该doc是第几个出现的相关文档
            if i == 1:
                DCG = rel                                   # DCGn = rel1 + sum(2-n)rel(i) / log2(i)
            else:
                DCG += (rel / math.log2(i))
        # 计算IDCG
        for result in sort_data:                            # [doc_id,rel]
            rel = result[1]
            i = sort_data.index(result) + 1
            IDCG += ((2 ** rel - 1) / math.log2(i + 1))     # IDCG = sum((2**rel- 1) / log2(i))

        NDCG += DCG / IDCG
        
    NDCG = NDCG / len(data_dict)                            # 计算NDCG
    return NDCG

In [12]:
NDCG_eval(data_dict)

0.7859518965266727

In [13]:
def evaluation():
    # query relevance file
    file_path = './qrels.txt'
    data_dict = generate_tweetid_gain(file_path)
    MAP = MAP_eval(data_dict)
    print('MAP', ' = ', round(MAP, 5), sep='')
    MRR = MRR_eval(data_dict)
    print('MRR', ' = ', round(MRR, 5), sep='')
    NDCG = NDCG_eval(data_dict)
    print('NDCG', ' = ', round(NDCG, 5), sep='')

In [14]:
if __name__ == '__main__':
    evaluation()

MAP = 0.87728
MRR = 0.79737
NDCG = 0.78595
