In [1]:
from functools import partial
import argparse
import os
import sys
import random
import time

import numpy as np
import hnswlib
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.utils.log import logger
import paddlenlp

from base_model import SemanticIndexBase
from data2 import convert_example, create_dataloader
from data2 import gen_id2corpus
from ann_util import build_index

  from .autonotebook import tqdm as notebook_tqdm
You can set full_graph=True, then you can assign input spec.



In [2]:
paddle.set_device("gpu") 

Place(gpu:0)

In [None]:
# 下面开始构建模型并载入模型参数

In [3]:
model_name = "ernie-1.0"

pretrained_model = paddlenlp.transformers.AutoModel.from_pretrained(model_name)

[32m[2025-02-06 22:38:50,806] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.modeling.ErnieModel'> to load 'ernie-1.0'.[0m
[32m[2025-02-06 22:38:50,808] [    INFO][0m - Already cached C:\Users\LJX\.paddlenlp\models\ernie-1.0\model_state.pdparams[0m
[32m[2025-02-06 22:38:50,809] [    INFO][0m - Loading weights file model_state.pdparams from cache at C:\Users\LJX\.paddlenlp\models\ernie-1.0\model_state.pdparams[0m
[32m[2025-02-06 22:38:51,163] [    INFO][0m - Loaded weights file from disk, setting weights to model.[0m
- This IS expected if you are initializing ErnieModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ErnieModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification 

In [4]:
model = SemanticIndexBase(pretrained_model, output_emb_size=256)

In [5]:
params_path = "model_param/model_180/model_state.pdparams" 

if params_path and os.path.isfile(params_path): 
    state_dict = paddle.load(params_path) 
    model.set_dict(state_dict) 
    print("Loaded parameters from %s" % params_path) 
else:
    raise ValueError("Please set params_path with correct pretrained model file")

Loaded parameters from model_param/model_180/model_state.pdparams


In [None]:
# 下面加载语料库文件，并利用语料库中的数据来构造ANN索引库

In [6]:
tokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained(model_name)

[32m[2025-02-06 22:39:05,394] [    INFO][0m - We are using (<class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'>, False) to load 'ernie-1.0'.[0m
[32m[2025-02-06 22:39:05,395] [    INFO][0m - Already cached C:\Users\LJX\.paddlenlp\models\ernie-1.0\vocab.txt[0m
[32m[2025-02-06 22:39:05,405] [    INFO][0m - tokenizer config file saved in C:\Users\LJX\.paddlenlp\models\ernie-1.0\tokenizer_config.json[0m
[32m[2025-02-06 22:39:05,407] [    INFO][0m - Special tokens file saved in C:\Users\LJX\.paddlenlp\models\ernie-1.0\special_tokens_map.json[0m


In [7]:
trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=60)

In [8]:
def batchify_fn(samples):
    fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  
    )

    processed_samples = fn(samples) 

    result = []
    for data in processed_samples:
        result.append(data) 

    return result

In [9]:
corpus_file = "recall_dataset/corpus.csv" 

id2corpus = gen_id2corpus(corpus_file) 

In [10]:
print(type(id2corpus))

<class 'dict'>


In [11]:
for i in range(10):
    print(id2corpus[i]) 

2002-2017年我国法定传染病发病率和死亡率时间变化趋势传染病,发病率,死亡率,病死率
陕西省贫困地区城乡青春期少女生长发育调查青春期,生长发育,贫困地区
五丈岩水库溢洪道加固工程中的新材料应用碳纤维布,粘钢加固技术,超细水泥,灌浆技术
木塑复合材料在儿童卫浴家具中的应用探索木塑复合材料,儿童,卫浴家具
泡沫铝准静态轴向压缩有限元仿真泡沫铝,准静态,轴向压缩,力学特性
An Analysis of the Potential of Import and Export Trade between China and the Countries along the Belt and Road
个体收入剥夺、医疗保障对我国城乡居民健康的影响城乡居民健康;个体收入剥夺;医疗保障;Logit模型
前列腺癌规范化标本取材及病理诊断共识
含碳增强体镁基复合材料的制备和界面调控的研究现状及发展趋势评述,复合材料,镁基,增强体,分散,界面,性能
小剂量氯胺酮对剖宫产术产后抑郁症的预防作用研究氯胺酮,剖宫产,抑郁症,产后并发症,小剂量


In [12]:
corpus_list = []
for idx, text in id2corpus.items():
    corpus_list.append({idx: text}) 

In [13]:
for i in range(10):
    print(corpus_list[i])

{0: '2002-2017年我国法定传染病发病率和死亡率时间变化趋势传染病,发病率,死亡率,病死率'}
{1: '陕西省贫困地区城乡青春期少女生长发育调查青春期,生长发育,贫困地区'}
{2: '五丈岩水库溢洪道加固工程中的新材料应用碳纤维布,粘钢加固技术,超细水泥,灌浆技术'}
{3: '木塑复合材料在儿童卫浴家具中的应用探索木塑复合材料,儿童,卫浴家具'}
{4: '泡沫铝准静态轴向压缩有限元仿真泡沫铝,准静态,轴向压缩,力学特性'}
{5: 'An Analysis of the Potential of Import and Export Trade between China and the Countries along the Belt and Road'}
{6: '个体收入剥夺、医疗保障对我国城乡居民健康的影响城乡居民健康;个体收入剥夺;医疗保障;Logit模型'}
{7: '前列腺癌规范化标本取材及病理诊断共识'}
{8: '含碳增强体镁基复合材料的制备和界面调控的研究现状及发展趋势评述,复合材料,镁基,增强体,分散,界面,性能'}
{9: '小剂量氯胺酮对剖宫产术产后抑郁症的预防作用研究氯胺酮,剖宫产,抑郁症,产后并发症,小剂量'}


In [14]:
corpus_ds = MapDataset(corpus_list) 

In [15]:
print(type(corpus_ds))

<class 'paddlenlp.datasets.dataset.MapDataset'>


In [16]:
for i in range(10):
    print(corpus_ds[i])

{0: '2002-2017年我国法定传染病发病率和死亡率时间变化趋势传染病,发病率,死亡率,病死率'}
{1: '陕西省贫困地区城乡青春期少女生长发育调查青春期,生长发育,贫困地区'}
{2: '五丈岩水库溢洪道加固工程中的新材料应用碳纤维布,粘钢加固技术,超细水泥,灌浆技术'}
{3: '木塑复合材料在儿童卫浴家具中的应用探索木塑复合材料,儿童,卫浴家具'}
{4: '泡沫铝准静态轴向压缩有限元仿真泡沫铝,准静态,轴向压缩,力学特性'}
{5: 'An Analysis of the Potential of Import and Export Trade between China and the Countries along the Belt and Road'}
{6: '个体收入剥夺、医疗保障对我国城乡居民健康的影响城乡居民健康;个体收入剥夺;医疗保障;Logit模型'}
{7: '前列腺癌规范化标本取材及病理诊断共识'}
{8: '含碳增强体镁基复合材料的制备和界面调控的研究现状及发展趋势评述,复合材料,镁基,增强体,分散,界面,性能'}
{9: '小剂量氯胺酮对剖宫产术产后抑郁症的预防作用研究氯胺酮,剖宫产,抑郁症,产后并发症,小剂量'}


In [17]:
batch_sampler = paddle.io.BatchSampler(corpus_ds, batch_size=64, shuffle=False)

corpus_data_loader = paddle.io.DataLoader(dataset=corpus_ds.map(trans_func), batch_sampler=batch_sampler, collate_fn=batchify_fn, return_list=True)

In [18]:
#如果需要从头构建索引，就运行下面这段代码。

output_emb_size = 256
hnsw_max_elements = 1000000 
hnsw_ef = 100 
hnsw_m = 100 

final_index = build_index(output_emb_size, hnsw_max_elements, hnsw_ef, hnsw_m, corpus_data_loader, model)

save_index_dir = "index_file" 
if not os.path.exists(save_index_dir):
    os.makedirs(save_index_dir)

save_index_path = os.path.join(save_index_dir, "final_index.bin") 
final_index.save_index(save_index_path)

[32m[2025-02-06 22:40:50,857] [    INFO][0m - start build index..........[0m
[32m[2025-02-06 22:49:45,512] [    INFO][0m - Total index number:300000[0m


In [None]:
'''
#如果有现成的索引文件final_index.bin,就运行这段代码
save_index_path = "index_file/final_index.bin"
output_emb_size = 256
final_index = hnswlib.Index(space="ip", dim=output_emb_size) 
final_index.load_index(save_index_path) 
'''

In [None]:
# 下面获取验证数据集中的所有query

In [19]:
def get_query_text(similar_text_pair_file): 
    querys = []
    with open(similar_text_pair_file, "r", encoding="utf-8") as f:
        for line in f:
            splited_line = line.rstrip().split("\t") 
            if len(splited_line) != 2: 
                continue

            if not splited_line[0] or not splited_line[1]: 
                continue

            querys.append({"text": splited_line[0]}) 

    return querys

In [20]:
similar_text_pair_file = "recall_dataset/dev.csv" 

query_list = get_query_text(similar_text_pair_file) 

In [21]:
print(type(query_list))

<class 'list'>


In [22]:
print(query_list)

[{'text': '热处理对尼龙6 及其与聚酰胺嵌段共聚物共混体系晶体熔融行为和结晶结构的影响'}, {'text': '面向生态系统服务的生态系统分类方案研发与应用.'}, {'text': 'huntington舞蹈病的动物模型'}, {'text': '试论我国海岸带经济开发的问题与前景'}, {'text': '外语阅读焦虑与英语成绩及性别的关系'}, {'text': '加油站风险分级管控'}, {'text': '溃疡性结肠炎 结肠癌'}, {'text': '雌激素受体 子宫内膜息肉 单发'}, {'text': '李雪梅护理甲状腺'}, {'text': '腐败治理下的监督与激励'}, {'text': '基于不平衡文本数据挖掘'}, {'text': '颈丛神经阻滞 AND 输液港'}, {'text': '脊柱肿瘤术后护理'}, {'text': '肺腺癌和小细胞癌鉴别诊断'}, {'text': 'P53 浆液性癌'}, {'text': '技术路线图--一种新型技术管理工具'}, {'text': '中国胰腺癌综合诊治指南'}, {'text': '脐血造血干细胞移植治疗白血病的研究进 展'}, {'text': '舰船动力系统仿真'}, {'text': '5-氨基酮戊酸光动力疗法治疗尖锐湿疣的临床研究'}, {'text': '财务杠杆和经营杠杆的优化路径'}, {'text': '急性心肌梗死 住院时间'}, {'text': '恶性胸腔积液 恩度'}, {'text': '电力外协施工队伍安全资质管控体系的研究与应用'}, {'text': '蠕墨铸铁 热分析'}, {'text': '马荣. 探讨心理护理干预对老年慢性心力衰竭患者认知功能状况及心功能的影响'}, {'text': 'shuanghuanglian'}, {'text': '早期肝癌 AFU'}, {'text': '区块链 财务造假'}, {'text': '传统的老年人心理咨询'}, {'text': 'Donkey skin'}, {'text': '小学生 学业情绪'}, {'text': '脊髓型颈椎病 神经恢复'}, {'text': '聚乳酸-羟基乙酸共聚物'}, {'text': '族裔聚居区的经济与社会——对聚居区族裔经济理论的检视与反思'}, {

In [23]:
query_ds = MapDataset(query_list) 

In [24]:
batch_sampler = paddle.io.BatchSampler(query_ds, batch_size=64, shuffle=False)

query_data_loader = paddle.io.DataLoader(dataset=query_ds.map(trans_func), batch_sampler=batch_sampler, collate_fn=batchify_fn, return_list=True)

In [25]:
query_embedding = model.get_semantic_embedding(query_data_loader) 

In [None]:
# 下面针对验证集中的query进行召回，生成召回结果文件

In [26]:
recall_result_dir = "recall_result_file" 
 
if not os.path.exists(recall_result_dir): 
    os.mkdir(recall_result_dir)

recall_result_file = "recall_result.txt" 

recall_result_file = os.path.join(recall_result_dir, recall_result_file) 

In [28]:
# 下面正式开始召回

with open(recall_result_file, "w", encoding="utf-8") as f: 
    for batch_index, batch_query_embedding in enumerate(query_embedding): 
        recalled_idx, cosine_sims = final_index.knn_query(batch_query_embedding.numpy(), 50) 

        batch_size = len(cosine_sims)

        for row_index in range(batch_size):
            text_index = 64 * batch_index + row_index 
            for idx, doc_idx in enumerate(recalled_idx[row_index]):
                f.write( "{}\t{}\t{}\n".format(query_list[text_index]["text"], id2corpus[doc_idx], 1.0 - cosine_sims[row_index][idx] ) )