In [3]:
import torch
from transformers import AutoTokenizer, AutoModel
import json
from tqdm import tqdm  # 导入tqdm库

# 加载模型 - Load model
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2", cache_dir='embedding_model', model_max_length=512)
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2", cache_dir='embedding_model')

# 平均池化 - Average pooling
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# # 获取文本的向量 - Get the vector of the text
def get_vector(text):
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # 使用平均池化获取文本向量 - Use average pooling to get text vectors
    input_ids = mean_pooling(model_output, encoded_input['attention_mask'])
    return input_ids

# 读取本地数据库并转换为二维向量存储 - Read the local database and convert it to a two-dimensional vector storage
def read_and_save(file):
    file = json.load(open(file, 'r', encoding='utf-8'))
    vectors = []
    for i in tqdm(range(len(file))):  # 使用tqdm来包裹循环
        line = str(file[i])
        vector = get_vector(line)
        vectors.append(vector)
    vectors = torch.cat(vectors, dim=0)
    torch.save(vectors, './local_vectors/vectors.pt') # 保存向量的路径及文件名
    print('generate vectors successfully!')

if __name__ == '__main__':
    read_and_save('gushiwen.json')

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

100%|██████████| 108197/108197 [4:38:59<00:00,  6.46it/s] 


generate vectors successfully!


In [13]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

# 读取本地数据库 - Read the local database
def read_local_vectors():
    vectors = torch.load('./local_vectors/vectors.pt')
    return vectors

# 将输入文本转化为向量并与数据库中的向量进行比较 - Convert the input text into a vector and compare it with the vector in the database
# 输出相似度最高的前n个文本的序号 - Output the serial number of the top 5 texts with the highest similarity
def get_domain_knowledge(text, n, threshold=0.2):
    # 参数介绍 - Parameter introduction
    # text: 输入文本 - Input text
    # n: 输出相似度最高的前n个文本的序号 - Output the serial number of the top 5 texts with the highest similarity
    # threshold: 概率阈值，小于该阈值的知识将被忽略 - Probability threshold, texts with probability less than this threshold will be ignored
    
    # 读取数据库中的向量 - Read the vector in the database
    vectors = read_local_vectors()
    # 将输入文本转化为向量 - Convert the input text into a vector
    input_vector = get_vector(text)
    # 将输入文本转化为numpy数组 - Convert the input text into a numpy array
    input_vector = input_vector.detach().numpy()
    # 将数据库中的向量转化为numpy数组 - Convert the vector in the database to a numpy array
    vectors = vectors.detach().numpy()
    # 计算每个问题与输入文本的相似度 - Calculate the similarity between each question and the input text
    similarity = cosine_similarity(input_vector, vectors)
    print(similarity)
    similarity_sorted = np.squeeze(similarity)
    # 按照相似度从大到小排序 - Sort by similarity from large to small
    similarity_sorted = np.argsort(-similarity_sorted)
    if n > len(similarity):
        n = len(similarity)
    if len(similarity) > 0:
        # 取出相似度最高的前n个文本的序号 - Take out the serial number of the top n texts with the highest similarity
        knowledges_ids = similarity_sorted[:n].tolist()
        # 读取知识库 - Read the knowledge base
        with open('gushiwen.json', 'r', encoding='utf8') as file:
            file_content = file.read()
            knowledges = json.loads(file_content)

            # 去除概率小于阈值的知识 - Remove knowledge with probability less than threshold
            knowledges_ids = [i for i in knowledges_ids if similarity[0][i] > threshold]
            # 直接输出资料文本 - directly output the text
            knowledges = [str(knowledge) for knowledge in knowledges]
            # 取出相似度最高的前n个文本 - Take out the top n texts with the highest similarity
            knowledges = [knowledges[i] for i in knowledges_ids]
        return knowledges
    return ''

# sample:
if __name__ == '__main__':
    input_text = '李白的诗歌'
    knowledges = get_domain_knowledge(input_text, 5)
    print(knowledges, len(knowledges))

[[0.49482936 0.5112034  0.49520627 ... 0.45510006 0.47185707 0.53729707]]
['{\'id\': 4387, \'href\': \'/shiwenv_97dccf96451a.aspx\', \'title\': \'长歌续短歌\', \'author\': \'李贺\', \'dynasty\': \'唐代\', \'content\': \'<br/>                    长歌破衣襟，短歌断白发。秦王不可见，旦夕成内热。渴饮壶中酒，饥拔陇头粟。凄凉四月阑，千里一时绿。夜峰何离离，明月落石底。徘徊沿石寻，照出高峰外。不得与之游，歌成鬓先改。 <br/>                \', \'sons\': {\'译文及注释\': {\'content\': \'译文<br/>写长歌把我的衣襟磨破，吟短诗使我的白发脱落。<br/>谒见秦王没有机缘，日夜焦虑我心中烦热。<br/>喝口壶中酒，聊以解渴，拔把垅头谷，暂充饥饿。<br/>四月将尽，千里大地一片绿色，自己却贫困潦倒，不由人感到凄凉难过。<br/>夜幕中峰峦起伏重叠，明亮的月光却只向谷底照射。<br/>我来来回回沿着石崖寻觅，可它又在高峰之外不可捉摸。<br/>自己终不得与其共事，歌成而头发早已变白。<br/>注释<br/>长歌续短歌：题目从古乐府《长歌行》、《短歌行》化出。<br/>长歌二句：互文的修辞手法，长歌短歌，唱破衣襟，吟断白发。<br/>秦王：指唐宪宗。宪宗当时在秦地，所以称为秦王。<br/>旦夕：日日夜夜。内热：内心急躁而炽热。<br/>陇头：田间地头。此二句比喻诗人如饥似渴地思念唐宪宗。<br/>凄凉二句：因为困顿潦倒，看到初夏万物茂盛，更加自感凄凉。<br/>离离：重叠、罗列的样子。<br/>明月：比喻唐宪宗。这两句的意思为：夜峰罗列，月光照耀在落石下，不及他处。比喻君恩被群小阻隔。<br/>裴回：即“徘徊”，彷徨不进貌。<br/>之：代词，代指唐宪宗。<br/>鬓先改：鬓发已经变白。<br/>\', \'cankao\': \'<br/>参考资料：完善<br/><br/>1、<br/>冯浩非 徐传武．李贺诗选译．成都：巴蜀书社，1991：112-114<br/><br/>\'}, \'赏析\': 