# 生成数据集

In [1]:
import pickle
import numpy as np
from tqdm import tqdm

## 导入所需的排序列表

In [3]:
with open('./ranked list truncation/dataset/drmm_train.pkl', 'rb') as f:
    drmm_train = pickle.load(f)
with open('./ranked list truncation/dataset/drmm_test.pkl', 'rb') as f:
    drmm_test = pickle.load(f)
with open('./ranked list truncation/dataset/bm25_train.pkl', 'rb') as f:
    bm25_train = pickle.load(f)
with open('./ranked list truncation/dataset/bm25_test.pkl', 'rb') as f:
    bm25_test = pickle.load(f)
with open('./ranked list truncation/dataset/drmm_tks_train.pkl', 'rb') as f:
    tks_train = pickle.load(f)
with open('./ranked list truncation/dataset/drmm_tks_test.pkl', 'rb') as f:
    tks_test = pickle.load(f)

### 提取用到的所有doc id，在后续的计算中节省时间

In [4]:
from tqdm import tqdm
doc_set = set()
for dataset in tqdm([drmm_train, drmm_test, bm25_train, bm25_test, tks_train, tks_test]):
    for qid in dataset:
        docs = list(dataset[qid].keys())
        doc_set.update(docs)
print(len(doc_set))

100%|██████████| 6/6 [00:00<00:00, 261.00it/s]

105439





### 展示数据格式

In [5]:
print(type(drmm_train['612']))
for index, key in enumerate(drmm_train['612']):
    print(key, drmm_train['612'][key])
    if index == 4: break

<class 'dict'>
FT944-610 0.889916
FBIS3-25845 0.879761
LA030689-0004 0.871522
LA012189-0056 0.870184
LA030989-0104 0.864167


## 导入所需的统计特征

* 根据特征分析，我们认为document length和unique token的数量事实上可以被tf-idf特征所完整描述

* 如上所述，我们在这里去掉document length和unique token的统计，只做tf-idf和word2vec的统计

In [6]:
with open('./ranked list truncation/data_prep/statics/tfidf.pkl', 'rb') as f:
    tfidf = pickle.load(f)
with open('./ranked list truncation/data_prep/statics/doc2vec.pkl', 'rb') as f:
    doc2vec = pickle.load(f)

## tfidf稀疏表示的dense化

In [7]:
def iv2dense(iv, total: int=231448):
    dense = [0] * total
    for item in iv:
        dense[item[0]] = item[1]
    return dense

In [8]:
tfidf_dense = {}
for doc in tqdm(doc_set):
    tfidf_dense[doc] = iv2dense(tfidf[doc])

100%|██████████| 105439/105439 [06:37<00:00, 265.43it/s]


In [8]:
with open('./ranked list truncation/data_prep/statics/tfidf_dense.pkl', 'wb') as f:
    pickle.dump(tfidf_dense, f)

## Bicut统计特征

In [9]:
import os

os.mkdir('./ranked list truncation/dataset/bicut/bm25_train/')
for qid in tqdm(bm25_train):
    stats = []
    for doc in bm25_train[qid]:
        doc_stats = [bm25_train[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/bm25_train/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

os.mkdir('./ranked list truncation/dataset/bicut/bm25_test/')
for qid in tqdm(bm25_test):
    stats = []
    for doc in bm25_test[qid]:
        doc_stats = [bm25_test[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/bm25_test/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

os.mkdir('./ranked list truncation/dataset/bicut/drmm_train/')
for qid in tqdm(drmm_train):
    stats = []
    for doc in drmm_train[qid]:
        doc_stats = [drmm_train[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/drmm_train/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

os.mkdir('./ranked list truncation/dataset/bicut/drmm_test/')
for qid in tqdm(drmm_test):
    stats = []
    for doc in drmm_test[qid]:
        doc_stats = [drmm_test[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/drmm_test/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

os.mkdir('./ranked list truncation/dataset/bicut/drmm_tks_train/')
for qid in tqdm(tks_train):
    stats = []
    for doc in tks_train[qid]:
        doc_stats = [tks_train[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/drmm_tks_train/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

os.mkdir('./ranked list truncation/dataset/bicut/drmm_tks_test/')
for qid in tqdm(tks_test):
    stats = []
    for doc in tks_test[qid]:
        doc_stats = [tks_test[qid][doc]] + tfidf_dense[doc]
        stats.append(doc_stats)
    with open('./ranked list truncation/dataset/bicut/drmm_tks_test/{}.pkl'.format(qid), 'wb') as f:
        pickle.dump(stats, f)

100%|██████████| 190/190 [12:46<00:00,  4.03s/it]
100%|██████████| 50/50 [02:30<00:00,  3.02s/it]
100%|██████████| 193/193 [09:48<00:00,  3.05s/it]
100%|██████████| 50/50 [02:32<00:00,  3.05s/it]
100%|██████████| 194/194 [13:00<00:00,  4.02s/it]
100%|██████████| 49/49 [02:29<00:00,  3.05s/it]


## Attncut统计特征

In [10]:
def cos_simi(x, y):
    num = x.dot(y.T)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    sim = (num / denom) if denom != 0 else 0
    return sim if not np.isnan(sim) else 0

### 根据数据集制作排序列表

In [12]:
def ranked_list(dataset):
    rl = {}
    for qid in dataset:
        rl[qid] = list(dataset[qid].keys())
    return rl

def simi_docs(doc0, doc1):
    tfidf_doc0, tfidf_doc1 = np.array(tfidf_dense[doc0]), np.array(tfidf_dense[doc1])
    d2v_doc0, d2v_doc1 = np.array(doc2vec[doc0]), np.array(doc2vec[doc1])
    simi_tfidf = cos_simi(tfidf_doc0, tfidf_doc1)
    simi_d2v = cos_simi(d2v_doc0, d2v_doc1)
    return [simi_tfidf, simi_d2v]

def simi_list(dataset):
    rl = ranked_list(dataset)
    sl = {}
    for qid in tqdm(rl):
        sl[qid] = [simi_docs(rl[qid][0], rl[qid][1])]
        for i in range(1, 299):
            simi_0 = simi_docs(rl[qid][i-1], rl[qid][i])
            simi_1 = simi_docs(rl[qid][i], rl[qid][i+1])
            simi = [(simi_0[0] + simi_1[0]) / 2, (simi_0[1] + simi_1[1]) / 2]
            sl[qid].append(simi)
        sl[qid].append(simi_docs(rl[qid][298], rl[qid][299]))
    return sl

In [13]:
bm25_train_sl = simi_list(bm25_train)
with open('./ranked list truncation/dataset/attncut/bm25_train.pkl', 'wb') as f:
    pickle.dump(bm25_train_sl, f)

bm25_test_sl = simi_list(bm25_test)
with open('./ranked list truncation/dataset/attncut/bm25_test.pkl', 'wb') as f:
    pickle.dump(bm25_test_sl, f)

drmm_train_sl = simi_list(drmm_train)
with open('./ranked list truncation/dataset/attncut/drmm_train.pkl', 'wb') as f:
    pickle.dump(drmm_train_sl, f)

drmm_test_sl = simi_list(drmm_test)
with open('./ranked list truncation/dataset/attncut/drmm_test.pkl', 'wb') as f:
    pickle.dump(drmm_test_sl, f)

drmm_tks_train_sl = simi_list(tks_train)
with open('./ranked list truncation/dataset/attncut/drmm_tks_train.pkl', 'wb') as f:
    pickle.dump(drmm_tks_train_sl, f)

drmm_tks_test_sl = simi_list(tks_test)
with open('./ranked list truncation/dataset/attncut/drmm_tks_test.pkl', 'wb') as f:
    pickle.dump(drmm_tks_test_sl, f)

100%|██████████| 190/190 [2:35:33<00:00, 49.12s/it]
100%|██████████| 50/50 [41:02<00:00, 49.24s/it]
100%|██████████| 193/193 [2:41:30<00:00, 50.21s/it]
100%|██████████| 50/50 [41:00<00:00, 49.20s/it]
100%|██████████| 194/194 [2:39:04<00:00, 49.20s/it]
100%|██████████| 49/49 [40:12<00:00, 49.23s/it]
