In [1]:
# Note: SearchEngine is an alias for the SparseRetriever
from retriv import SearchEngine
import csv
from typing import Dict
import csv 
import collections
import jieba
csv.field_size_limit(500 * 1024 * 1024)
from tqdm import tqdm

### 去掉停用词

In [2]:
# Collect the Chinese stop words: one entry per line of the file, with
# surrounding whitespace (including the newline) stripped.
stopwords = set()
with open("../cn_stopwords.txt", "r", encoding="utf-8") as stopword_file:
    stopwords.update(entry.strip() for entry in stopword_file)

### 处理 corpus.tsv 并构造bm25搜索引擎


In [3]:
## Tokenize the corpus
pid_to_passage = []
datapath = "corpus.tsv"


def create_id_item_dict(file_path: str, delimiter: str = "\t") -> None:
    """Read a two-column TSV of (pid, passage) rows and tokenize every passage.

    Each row is appended to the module-level ``pid_to_passage`` list as a
    ``(pid, passage, tokens)`` tuple, where ``tokens`` is the jieba
    segmentation of the passage with every entry of ``stopwords`` removed.
    Nothing is returned; the result lives in ``pid_to_passage``.

    Args:
        file_path: Path of the UTF-8 encoded TSV file to read.
        delimiter: Column separator handed to ``csv.reader``.

    Raises:
        ValueError: If a row does not contain exactly two columns.
    """
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        for pid, passage in tqdm(csv.reader(f, delimiter=delimiter)):
            tokens = [word for word in jieba.cut(passage) if word not in stopwords]
            pid_to_passage.append((pid, passage, tokens))


create_id_item_dict(datapath)
print("分词完成")
# Build the {"id", "text"} records that are later passed to SearchEngine.index;
# the text is the space-joined token list so it can be split on whitespace again.
pid_to_passage_to_save = [
    {"id": pid, "text": " ".join(tokens)} for pid, _, tokens in pid_to_passage
]
print(f"courpus_len == {len(pid_to_passage)}")
pid_to_passage_to_save[0]


corpus.tsv


0it [00:00, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.032 seconds.
Prefix dict has been built successfully.
1000000it [04:31, 3688.29it/s]


分词完成
courpus_len == 1000000


{'id': '1637165', 'text': '赵雷 原创 歌曲   张韶涵 唱火'}

In [4]:
## Build the search engine
# Create the BM25 retriever first, then index the tokenized corpus records.
bm25_engine = SearchEngine(model="bm25")
serachengine_obj = bm25_engine.index(pid_to_passage_to_save)

Building TDF matrix: 100%|██████████| 1000000/1000000 [00:30<00:00, 32724.59it/s]
Building inverted index: 100%|██████████| 586321/586321 [03:11<00:00, 3062.42it/s]


#### 加载 训练集数据 "train.queries.tsv" 并进行搜索 保存数据

In [5]:
train_qid_to_query = []
datapath = "train.queries.tsv"


def create_id_item_dict(file_path: str, delimiter: str = "\t") -> None:
    """Read a two-column TSV of (qid, query) rows and tokenize every query.

    Each row is appended to the module-level ``train_qid_to_query`` list as a
    ``(qid, query, tokens)`` tuple, where ``tokens`` is the jieba segmentation
    of the query with every entry of ``stopwords`` removed. Nothing is
    returned; the result lives in ``train_qid_to_query``.

    Args:
        file_path: Path of the UTF-8 encoded TSV file to read.
        delimiter: Column separator handed to ``csv.reader``.

    Raises:
        ValueError: If a row does not contain exactly two columns.
    """
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        for qid, query in csv.reader(f, delimiter=delimiter):
            tokens = [word for word in jieba.cut(query) if word not in stopwords]
            train_qid_to_query.append((qid, query, tokens))


create_id_item_dict(datapath)
print("分词完成")
# Build the {"id", "text"} query records that are later passed to bsearch;
# the text is the space-joined token list.
train_qid_to_query_to_save = [
    {"id": qid, "text": " ".join(tokens)} for qid, _, tokens in train_qid_to_query
]
print(f"train_qid_len == {len(train_qid_to_query)}")
train_qid_to_query_to_save[0]


train.queries.tsv
分词完成
train_qid_len == 96279


{'id': '6', 'text': '谍战 电视剧 战争'}

In [6]:
## Search results
# Run every training query through the index in one batched call, keeping at
# most `cutoff` hits per query.
train_bm25_top100_res = serachengine_obj.bsearch(
    train_qid_to_query_to_save,
    cutoff=100,
    batch_size=10000,
)
train_result_count = len(train_bm25_top100_res)
print(f"train_bm25_len == {train_result_count}")


Batch search: 100%|██████████| 96279/96279 [02:37<00:00, 612.84it/s]

train_bm25_len == 96279





In [7]:
save_path = "train.top100.bm25.tsv"
print(f"save to {save_path}")

# Total number of (query, passage) rows written, excluding the header.
count = 0
# The context manager guarantees the file is flushed and closed even if a write
# fails part-way through; UTF-8 matches the encoding used to read the inputs.
with open(save_path, "w", newline="", encoding="utf-8") as fp:
    writer = csv.writer(fp, delimiter="\t")
    writer.writerow(["qid", "pid", "rank", "score"])
    for qid, rels_ in tqdm(train_bm25_top100_res.items()):
        # Rank is the 1-based position of a hit in the order bsearch returned it.
        # NOTE(review): assumes hits come back best-first - confirm against retriv.
        for rank, (pid, score) in enumerate(rels_.items(), start=1):
            writer.writerow([qid, pid, rank, score])
            count += 1
print(f"count = {count}")

save to train.top100.bm25.tsv


100%|██████████| 96279/96279 [00:31<00:00, 3023.69it/s]

count = 8958455





### 加载验证数据 dev.queries.tsv 并进行搜索 保存数据

In [8]:
dev_qid_to_query = []
datapath = "dev.queries.tsv"


def create_id_item_dict(file_path: str, delimiter: str = "\t") -> None:
    """Read a two-column TSV of (qid, query) rows and tokenize every query.

    Each row is appended to the module-level ``dev_qid_to_query`` list as a
    ``(qid, query, tokens)`` tuple, where ``tokens`` is the jieba segmentation
    of the query with every entry of ``stopwords`` removed. Nothing is
    returned; the result lives in ``dev_qid_to_query``.

    Args:
        file_path: Path of the UTF-8 encoded TSV file to read.
        delimiter: Column separator handed to ``csv.reader``.

    Raises:
        ValueError: If a row does not contain exactly two columns.
    """
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        for qid, query in csv.reader(f, delimiter=delimiter):
            tokens = [word for word in jieba.cut(query) if word not in stopwords]
            dev_qid_to_query.append((qid, query, tokens))


create_id_item_dict(datapath)
# Build the {"id", "text"} query records that are later passed to bsearch and
# autotune; the text is the space-joined token list.
dev_pid_to_query_to_save = [
    {"id": qid, "text": " ".join(tokens)} for qid, _, tokens in tqdm(dev_qid_to_query)
]
print(len(dev_pid_to_query_to_save))

dev.queries.tsv


100%|██████████| 1000/1000 [00:00<00:00, 199188.11it/s]

1000





In [9]:
# Run every dev query through the index in one batched call, keeping at most
# `cutoff` hits per query.
dev_bm25_top100_res = serachengine_obj.bsearch(
    dev_pid_to_query_to_save,
    cutoff=100,
    batch_size=10000,
)

dev_result_count = len(dev_bm25_top100_res)
print(f"dev_bm25_len == {dev_result_count}")

Batch search: 100%|██████████| 1000/1000 [00:01<00:00, 698.69it/s]

dev_bm25_len == 1000





In [10]:
save_path = "dev.top100.bm25.tsv"
print(f"save to {save_path}")

# Total number of (query, passage) rows written, excluding the header.
count = 0
# The context manager guarantees the file is flushed and closed even if a write
# fails part-way through; UTF-8 matches the encoding used to read the inputs.
with open(save_path, "w", newline="", encoding="utf-8") as fp:
    writer = csv.writer(fp, delimiter="\t")
    writer.writerow(["qid", "pid", "rank", "score"])
    for qid, v in tqdm(dev_bm25_top100_res.items()):
        # Rank is the 1-based position of a hit in the order bsearch returned it.
        # NOTE(review): assumes hits come back best-first - confirm against retriv.
        for rank, (pid, score) in enumerate(v.items(), start=1):
            writer.writerow([qid, pid, rank, score])
            count += 1
print(f"count_len={count}")

save to dev.top100.bm25.tsv


100%|██████████| 1000/1000 [00:00<00:00, 3161.37it/s]

count_len=92342





## Test autotune

In [None]:
qrels_obj = {}
datapath = "dev.qrels.tsv"


def get_qrels(file_path: str, delimiter: str = "\t") -> None:
    """Load relevance judgments from a four-column TSV into ``qrels_obj``.

    Every row is unpacked as ``(qid, <ignored>, pid, rel)`` and stored in the
    module-level ``qrels_obj`` dict as ``qrels_obj[qid][pid] = int(rel)``.
    Nothing is returned; the result lives in ``qrels_obj``.

    Args:
        file_path: Path of the UTF-8 encoded TSV file to read.
        delimiter: Column separator handed to ``csv.reader``.

    Raises:
        ValueError: If a row does not contain exactly four columns, or if the
            relevance column is not an integer literal.
    """
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        for qid, _, pid, rel in tqdm(csv.reader(f, delimiter=delimiter)):
            # csv yields strings; store the label as a number.
            # NOTE(review): assumes integer relevance labels and no header row
            # in the file - confirm against dev.qrels.tsv.
            qrels_obj.setdefault(qid, {})[pid] = int(rel)


get_qrels(datapath)
# Show only the number of judged queries instead of dumping the whole mapping.
len(qrels_obj)

In [18]:
# Tune the engine's hyperparameters using the dev queries and dev qrels built above.
serachengine_obj.autotune(
  queries=dev_pid_to_query_to_save,  # Tokenized dev queries (not train queries)
  qrels=qrels_obj,      # Relevance judgments loaded from dev.qrels.tsv
  metric="ndcg",  # Default value, metric to maximize
  n_trials=100,   # Default value, number of trials
  cutoff=100,     # Default value, number of results
)

  0%|          | 0/100 [00:00<?, ?it/s]

In [21]:
serachengine_obj.hyperparams
## The BM25 formula has three free tuning parameters: besides the adjustment
## factor b, there are the term-frequency adjustment factors k1 and k2.
## k1 scales the in-document frequency of a query term: with k1 = 0 the second
## factor of the formula becomes the integer 1, so term frequency is ignored and
## the model degenerates into the binary independence model. With a large k1 the
## second factor grows roughly linearly with the term frequency f_i, i.e. term
## frequency is weighted more heavily. Empirically, k1 is usually set to 1.2.


{'b': 0.85, 'k1': 0.2}