In [1]:
import pandas as pd

data_origin = pd.read_csv("train.tsv", sep="\t", header=None, names=["q0","qa","label"])
data = data_origin.loc[data_origin["label"]==1].copy()
data

Unnamed: 0,q0,qa,label
0,喜欢打篮球的男生喜欢什么样的女生,爱打篮球的男生喜欢什么样的女生,1
1,我手机丢了，我想换个手机,我想买个新手机，求推荐,1
3,求秋色之空漫画全集,求秋色之空全集漫画,1
5,学日语软件手机上的,手机学日语的软件,1
7,侠盗飞车罪恶都市怎样改车,侠盗飞车罪恶都市怎么改车,1
...,...,...,...
238754,如何快速美白全身,怎样能快速全身美白,1
238755,这个表情叫什么,这个表情是什么,1
238758,世界上什么东西最小,世界上什么东西最小？,1
238762,求重生之老公请接招全文,求重生之老公请接招>全文,1


In [2]:
import numpy as np
import jieba
from zhon.hanzi import punctuation
import re

query_list = pd.DataFrame(set(np.array(data_origin[["q0","qa"]]).reshape(-1)), columns=["query"])
query_list = query_list.copy()
query_list["removepunc"] = query_list["query"].map(lambda x : re.sub(r"[%s><,./?;:'\"]+" %punctuation, "", x))
query_list["split"] = query_list["removepunc"].map(lambda x : list(jieba.cut(str(x))))
query_list

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\nagis\AppData\Local\Temp\jieba.cache
Loading model cost 0.457 seconds.
Prefix dict has been built successfully.


Unnamed: 0,query,removepunc,split
0,是桃花眼吗？,是桃花眼吗,"[是, 桃花, 眼, 吗]"
1,求筷子兄弟微电影《父亲》,求筷子兄弟微电影父亲,"[求, 筷子, 兄弟, 微, 电影, 父亲]"
2,红色的字是什么意思,红色的字是什么意思,"[红色, 的, 字, 是, 什么, 意思]"
3,哪里可以下载…微信4.2旧版本？请问,哪里可以下载微信42旧版本请问,"[哪里, 可以, 下载, 微信, 42, 旧版本, 请问]"
4,金毛起什么名字,金毛起什么名字,"[金毛, 起, 什么, 名字]"
...,...,...,...
245946,什么导航软件最好,什么导航软件最好,"[什么, 导航, 软件, 最好]"
245947,烧伤怎么处理,烧伤怎么处理,"[烧伤, 怎么, 处理]"
245948,意大利和乌拉圭谁能赢,意大利和乌拉圭谁能赢,"[意大利, 和, 乌拉圭, 谁, 能, 赢]"
245949,谁知道这个美女的名字,谁知道这个美女的名字,"[谁, 知道, 这个, 美女, 的, 名字]"


In [3]:
from tqdm import tqdm

query_list_short = query_list[["removepunc","split"]].copy().drop_duplicates(subset=['removepunc'])
query2idx = {}
query2query = {}

for idx in tqdm(query_list_short.index):
    query2idx[query_list_short.loc[idx]["removepunc"]] = idx
    
for idx in tqdm(query_list.index):
    query2query[query_list.loc[idx]["query"]] = query_list.loc[idx]["removepunc"]

100%|██████████| 225884/225884 [00:04<00:00, 48822.63it/s]
100%|██████████| 245951/245951 [00:12<00:00, 20267.23it/s]


In [4]:
data["qa_encode"] = data["qa"].map(lambda x : query2idx[query2query[x]])
data["q0_encode"] = data["q0"].map(lambda x : query2idx[query2query[x]])
data = data.drop_duplicates(subset=["qa_encode","q0_encode"])
data = data.loc[data["qa_encode"] != data["q0_encode"]]
data

Unnamed: 0,q0,qa,label,qa_encode,q0_encode
0,喜欢打篮球的男生喜欢什么样的女生,爱打篮球的男生喜欢什么样的女生,1,154608,206341
1,我手机丢了，我想换个手机,我想买个新手机，求推荐,1,94091,224714
3,求秋色之空漫画全集,求秋色之空全集漫画,1,2567,41172
5,学日语软件手机上的,手机学日语的软件,1,50891,130904
7,侠盗飞车罪恶都市怎样改车,侠盗飞车罪恶都市怎么改车,1,124814,228073
...,...,...,...,...,...
238748,上海哪里收驾照分？,上海驾照分哪里收？,1,47688,19485
238749,我的世界带什么去挖铁矿,我的世界里铁矿去哪挖,1,236236,238016
238750,舅舅的儿子叫什么,我的舅舅叫我的儿子叫什么啊,1,171115,116085
238754,如何快速美白全身,怎样能快速全身美白,1,36367,154430


In [18]:
def get_root(query_list: pd.DataFrame, pos: int):
    if query_list["root"][pos] == pos:
        return pos
    return get_root(query_list, query_list["root"][pos])

def union(query_list: pd.DataFrame, pos1: int, pos2: int):
    query_list.at[get_root(query_list,pos1),"root"] = get_root(query_list,pos2)

In [19]:
query_list_short["root"] = query_list_short.index

for idx in tqdm(data.index):
    line = data.loc[idx]
    union(query_list_short, line["qa_encode"], line["q0_encode"])

100%|██████████| 117762/117762 [00:09<00:00, 12154.39it/s]


In [21]:
import math

class BM25(object):
    def __init__(self, docs):
        self.D = len(docs)
        self.avgdl = sum([len(doc)+0.0 for doc in docs]) / self.D
        self.docs = docs
        self.f = []  # 列表的每一个元素是一个dict，dict存储着一个文档中每个词的出现次数
        self.df = {} # 存储每个词及出现了该词的文档数量
        self.idf = {} # 存储每个词的idf值
        self.k1 = 1.5
        self.b = 0.75
        self.init()

    def init(self):
        for doc in self.docs:
            tmp = {}
            for word in doc:
                tmp[word] = tmp.get(word, 0) + 1  # 存储每个文档中每个词的出现次数
            self.f.append(tmp)
            for k in tmp.keys():
                self.df[k] = self.df.get(k, 0) + 1
        for k, v in self.df.items():
            self.idf[k] = math.log(self.D-v+0.5)-math.log(v+0.5)

    def sim(self, doc, index):
        score = 0
        for word in doc:
            if word not in self.f[index]:
                continue
            d = len(self.docs.iloc[index])
            score += (self.idf[word]*self.f[index][word]*(self.k1+1)
                      / (self.f[index][word]+self.k1*(1-self.b+self.b*d
                                                      / self.avgdl)))
        return score

    # 总共有N篇文档，传来的doc为查询文档，计算doc与所有文档匹配
    # 后的得分score，总共有多少篇文档，scores列表就有多少项，
    # 每一项为doc与这篇文档的得分，所以分清楚里面装的是文档得分，
    # 不是词语得分。
    def simall(self, doc):
        scores = []
        for index in range(self.D):
            score = self.sim(doc, index)
            scores.append([score,self.docs.index[index]])
        scores.sort(key = lambda x: -x[0])
        return scores

In [22]:
bm25 = BM25(query_list_short["split"])

In [23]:
data_batch_0 = data.iloc[:2000].copy()
data_batch_remain = data.iloc[5:].copy()

In [26]:
q_neg_bm25 = []
for idx in tqdm(data_batch_0.index):
    line = data_batch_0.loc[idx]
    li_bm25 = list(map(lambda x : x[1], bm25.simall(query_list_short["split"].loc[line["q0_encode"]])))
    
    bm25_final_list = []
    for bm25query in li_bm25:
        if get_root(query_list_short, bm25query) == get_root(query_list_short, line["q0_encode"]):
            continue
        
        bm25_final_list.append(bm25query)
        
        if len(bm25_final_list) >= 5:
            break
        
    q_neg_bm25.append(bm25_final_list)

100%|██████████| 5/5 [00:03<00:00,  1.42it/s]


In [27]:
data_batch_0["qn"] = q_neg_bm25
data_batch_0 = data_batch_0.copy().dropna()
data_batch_0["qn1"] = data_batch_0["qn"].map(lambda x : query_list_short["removepunc"].loc[x[0]])
data_batch_0["qn2"] = data_batch_0["qn"].map(lambda x : query_list_short["removepunc"].loc[x[1]])
data_batch_0["qn3"] = data_batch_0["qn"].map(lambda x : query_list_short["removepunc"].loc[x[2]])
data_batch_0["qn4"] = data_batch_0["qn"].map(lambda x : query_list_short["removepunc"].loc[x[3]])
data_batch_0["qn5"] = data_batch_0["qn"].map(lambda x : query_list_short["removepunc"].loc[x[4]])
data_batch_0

Unnamed: 0,q0,qa,label,qa_encode,q0_encode,qn,qn1,qn2,qn3,qn4,qn5
0,喜欢打篮球的男生喜欢什么样的女生,爱打篮球的男生喜欢什么样的女生,1,154608,206341,"[202155, 35903, 58382, 117745, 11514]",女生喜欢打篮球的男生吗,女生喜欢看男生打篮球吗,男生喜欢什么样的女生,女生喜欢什么样的男生,女生都喜欢什么样的男生
1,我手机丢了，我想换个手机,我想买个新手机，求推荐,1,94091,224714,"[8368, 178456, 132240, 151491, 165385]",我手机丢了怎么找回,我的苹果手机丢了怎么找回,我的小米手机丢了怎么找回来,我的手机丢了能找回来吗,我的世界手机版走丢了怎么办
3,求秋色之空漫画全集,求秋色之空全集漫画,1,2567,41172,"[55484, 73914, 129821, 156484, 167390]",求秋色之空动漫全集,求秋色之空动画全集,求秋色之空全集动画,求秋色之空的动漫全集,动漫秋色之空
5,学日语软件手机上的,手机学日语的软件,1,50891,130904,"[79258, 2151, 72495, 83946, 207089]",手机学日语,学日语哪个软件好,什么软件可以学日语,好的学日语的软件,学日语用什么软件好
7,侠盗飞车罪恶都市怎样改车,侠盗飞车罪恶都市怎么改车,1,124814,228073,"[235111, 10829, 56527, 104560, 229487]",侠盗飞车罪恶都市,侠盗飞车罪恶都市秘籍,侠盗飞车修改器罪恶都市,侠盗飞车之罪恶都市,侠盗飞车罪恶都市修改器


In [38]:
len(data)

117762

In [46]:
with open("data.csv", "w") as F:
    F.write("idx,q0,qa,qn1,qn2,qn3,qn4,qn5\n")

data_batch_remain = data.copy()
used_list = []

batchsize = 200

while len(data_batch_remain) > 0:
    data_batch_working = data_batch_remain.iloc[:batchsize].copy()
    data_batch_remain = data_batch_remain.iloc[batchsize:].copy()
    
    q_neg_bm25 = []
    for idx in tqdm(data_batch_working.index):
        line = data_batch_working.loc[idx]
        q0root = get_root(query_list_short, line["q0_encode"])
        if q0root in used_list:
            q_neg_bm25.append(None)
            continue
        
        used_list.append(q0root)
        
        li_bm25 = list(map(lambda x : x[1], bm25.simall(query_list_short["split"].loc[line["q0_encode"]])))
        
        bm25_final_list = []
        for bm25query in li_bm25:
            if get_root(query_list_short, bm25query) == q0root:
                continue
            
            bm25_final_list.append(bm25query)
            
            if len(bm25_final_list) >= 5:
                break
            
        q_neg_bm25.append(bm25_final_list)

    data_batch_working["qn"] = q_neg_bm25
    data_batch_working = data_batch_working.copy().dropna()
    data_batch_working["qn1"] = data_batch_working["qn"].map(lambda x : query_list["removepunc"].loc[x[0]])
    data_batch_working["qn2"] = data_batch_working["qn"].map(lambda x : query_list["removepunc"].loc[x[1]])
    data_batch_working["qn3"] = data_batch_working["qn"].map(lambda x : query_list["removepunc"].loc[x[2]])
    data_batch_working["qn4"] = data_batch_working["qn"].map(lambda x : query_list["removepunc"].loc[x[3]])
    data_batch_working["qn5"] = data_batch_working["qn"].map(lambda x : query_list["removepunc"].loc[x[4]])
    data_batch_working[["q0","qa","qn1","qn2","qn3","qn4","qn5"]].to_csv("data.csv", mode='a', header=False)
    
    data_batch_remain.to_csv("remain.csv")

100%|██████████| 200/200 [02:34<00:00,  1.29it/s]
100%|██████████| 200/200 [02:31<00:00,  1.32it/s]
100%|██████████| 200/200 [02:37<00:00,  1.27it/s]
100%|██████████| 200/200 [02:29<00:00,  1.34it/s]
100%|██████████| 200/200 [02:36<00:00,  1.28it/s]
100%|██████████| 200/200 [02:33<00:00,  1.30it/s]
100%|██████████| 200/200 [02:22<00:00,  1.41it/s]
100%|██████████| 200/200 [02:22<00:00,  1.41it/s]
100%|██████████| 200/200 [02:29<00:00,  1.34it/s]
100%|██████████| 200/200 [02:10<00:00,  1.53it/s]
100%|██████████| 200/200 [02:28<00:00,  1.35it/s]
100%|██████████| 200/200 [02:21<00:00,  1.41it/s]
100%|██████████| 200/200 [02:08<00:00,  1.55it/s]
100%|██████████| 200/200 [02:08<00:00,  1.55it/s]
100%|██████████| 200/200 [02:09<00:00,  1.54it/s]
100%|██████████| 200/200 [02:19<00:00,  1.44it/s]
100%|██████████| 200/200 [02:12<00:00,  1.51it/s]
100%|██████████| 200/200 [02:21<00:00,  1.41it/s]
100%|██████████| 200/200 [02:26<00:00,  1.37it/s]
100%|██████████| 200/200 [02:28<00:00,  1.35it/s]


In [None]:
data_batch_working