In [None]:
# Path of the resource file to ingest; later cells dispatch on its extension
# (.pdf, .docx or .pptx are handled by read_file).
param1= "../database/resources/xxxx.pdf"

## 从 neo4j 中获得编号 EC_R_

In [None]:
from neo4j import GraphDatabase

import os

# Neo4j connection settings. Each value can be overridden through an
# environment variable so that real credentials never have to be saved in the
# notebook; the defaults keep the original local-development values.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")  # Neo4j Bolt URI
username = os.environ.get("NEO4J_USERNAME", "neo4j")  # Neo4j user name
password = os.environ.get("NEO4J_PASSWORD", "xxx")  # Neo4j password

driver = GraphDatabase.driver(uri, auth=(username, password))

In [None]:
def get_node_count_with_prefix(driver, prefix):
    """Count the nodes whose ``id`` property starts with ``prefix``.

    Args:
        driver: An open Neo4j driver (any object exposing ``session()``).
        prefix: String the node ``id`` must start with.

    Returns:
        The value of Cypher ``count(n)`` for the matching nodes.
    """
    def count_nodes_with_prefix(tx, prefix):
        # The prefix travels as a query parameter, it is never interpolated.
        result = tx.run("MATCH (n) WHERE n.id STARTS WITH $prefix RETURN count(n)", prefix=prefix)
        return result.single()[0]

    with driver.session() as session:
        # ``read_transaction`` is deprecated in the 5.x driver in favour of
        # ``execute_read``; prefer the new name and fall back for 4.x sessions.
        run_read = getattr(session, "execute_read", None) or session.read_transaction
        return run_read(count_nodes_with_prefix, prefix)

In [None]:
prefix = "EC_R"  # id prefix to count

# Count the nodes whose id starts with the prefix and print the result.
node_count = get_node_count_with_prefix(driver, prefix)

print(f"以'{prefix}'开头的实体节点数量为: {node_count}")

# Close the driver connection.
driver.close()

In [None]:
from pymongo import MongoClient
import gridfs
from gridfs import GridFS

# Connect to MongoDB. The handle is deliberately not called ``client``: that
# name is rebound to the ZhipuAI client later in this notebook, which would
# silently replace the MongoDB handle.
mongo_client = MongoClient('mongodb://localhost:27017/')
db = mongo_client['mydatabase']  # database that stores the uploaded files

# GridFS bucket used to store and look up the resource files.
bucket = gridfs.GridFSBucket(db)

In [None]:
import os
import time
import json
from dotenv import load_dotenv, find_dotenv


# Load environment variables from the local/project .env file.
_ = load_dotenv(find_dotenv())

# Requires the zhipuai package: run `pip install zhipuai` in a terminal first.
import os
from zhipuai import ZhipuAI
client = ZhipuAI(api_key=os.environ["ZHIPUAI_API_KEY"])

def get_completion_glm_flash(prompt, model="glm-4-flash", temperature_=0.01, max_tokens_=4095):
    """Send a single-turn user prompt to a ZhipuAI chat model and return the reply.

    Args:
        prompt: Text sent as the only (user) message.
        model: Name of the chat model to call.
        temperature_: Sampling temperature forwarded to the API.
        max_tokens_: Maximum number of generated tokens forwarded to the API.

    Returns:
        The content of the first choice, or the string
        ``"generate answer error"`` when the response contains no choices.
    """
    response = client.chat.completions.create(
        # Honour the ``model`` argument; it used to be ignored in favour of a
        # hard-coded "glm-4-flash".
        model=model,
        messages=[
            {
                "role": "user",
                "content": "%s" % (prompt)
            }
        ],
        top_p=0.7,
        temperature=temperature_,
        max_tokens=max_tokens_,
        # The web_search tool is declared with search_result set to False.
        tools=[{"type": "web_search", "web_search": {"search_result": False}}],
        stream=False
    )
    if len(response.choices) > 0:
        return response.choices[0].message.content
    return "generate answer error"

In [None]:
import os
from pptx import Presentation
import pandas as pd
from pdf2image import convert_from_path
from docx import Document
import shutil
import requests
import base64


def handle_other(file_path):
    """Placeholder handler for unsupported file types; does nothing."""
    return None

# 定义处理不同文件类型的函数
def handle_pptx(file_path):
    """Collect the text of every shape in a PowerPoint file.

    Args:
        file_path: Path of the .pptx file.

    Returns:
        The ``text`` of each shape that has one, in slide order, joined with
        newlines.
    """
    deck = Presentation(file_path)
    shape_texts = [
        shape.text
        for slide in deck.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    ]
    return '\n'.join(shape_texts)


def handle_docx(file_path):
    """Extract the paragraph text of a Word document.

    Args:
        file_path: Path of the .docx file.

    Returns:
        The text of every paragraph in ``doc.paragraphs``, joined with
        newlines.
    """
    paragraphs = Document(file_path).paragraphs
    return '\n'.join(paragraph.text for paragraph in paragraphs)

# 下面是处理 pdf 的代码，我觉得可以考虑使用 pymupdf 来读取 pdf 文件内容
# from langchain.document_loaders.pdf import PyMuPDFLoader 解析 PDF
# https://datawhalechina.github.io/llm-universe/#/C3/3.%E6%95%B0%E6%8D%AE%E5%A4%84%E7%90%86

def ocr_images(file_path):
    """Send one image to the local OCR HTTP endpoint and return the recognised text.

    The image is base64-encoded and posted as JSON to
    ``http://127.0.0.1:1224/api/ocr``.

    Args:
        file_path: Path of the image file to recognise.

    Returns:
        The ``data`` field of the JSON response when the HTTP status is 200,
        otherwise ``None``.
    """
    with open(file_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode()

    # Optional request parameters; "data.format": "text" asks for the
    # recognised text only, without position information.
    ocr_options = {
        # general parameters
        "tbpu.parser": "single_para",
        "data.format": "text",
        # engine parameters
        "ocr.cls": False,
        "ocr.language": "models/config_chinese.txt",
        "ocr.maxSideLen": 1024
    }
    payload = json.dumps({"base64": encoded_image, "options": ocr_options})

    response = requests.post(
        "http://127.0.0.1:1224/api/ocr",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    if response.status_code != 200:
        return None
    return json.loads(response.text)["data"]

def handle_pdf(file_path):
    """Extract the text of a PDF by rendering each page to PNG and OCR-ing it.

    Pages are rendered with ``convert_from_path``, written to the scratch
    folder ``../database/pdf2img/`` and passed to ``ocr_images`` one by one.

    Args:
        file_path: Path of the PDF file.

    Returns:
        The OCR text of all pages concatenated in page order. Pages for which
        ``ocr_images`` returns ``None`` contribute nothing.
    """
    image_folder_path = '../database/pdf2img/'
    os.makedirs(image_folder_path, exist_ok=True)

    # File name without directory or extension, used as the image name prefix.
    file_name = os.path.splitext(os.path.basename(file_path))[0]

    try:
        page_texts = []
        # OCR exactly the images written here, in page order. Listing the
        # folder instead would return files in arbitrary order and would also
        # pick up unrelated files that happen to be in it.
        for page_number, page in enumerate(convert_from_path(file_path), start=1):
            img_path = os.path.join(image_folder_path, f"{file_name}_{page_number:03d}.png")
            page.save(img_path, 'PNG')
            page_text = ocr_images(img_path)
            # A failed OCR request yields None; do not add the text "None".
            if page_text is not None:
                page_texts.append(str(page_text))
        return ''.join(page_texts)
    finally:
        # Remove the scratch folder even when rendering or OCR fails.
        shutil.rmtree(image_folder_path)

import re
import jieba
def process_text(stopwords_file_path: str, dictionary_file_path: str, text: str) -> str:
    """Pre-process Chinese text: keep Chinese characters, segment, drop stop words.

    Args:
        stopwords_file_path: UTF-8 text file with one stop word per line.
        dictionary_file_path: Custom dictionary file passed to
            ``jieba.load_userdict``.
        text: Text to process.

    Returns:
        The remaining words concatenated into one string without separators.
    """
    # Remove every character outside the range U+4E00..U+9FA5.
    text_cn = re.sub(r'[^\u4e00-\u9fa5]', '', text)

    jieba.load_userdict(dictionary_file_path)  # register the custom dictionary

    # ``with`` closes the file (it used to stay open). A set gives constant
    # time membership tests. Surrounding whitespace is stripped because the
    # tokens consist of Chinese characters only, so an entry with stray
    # whitespace could never match.
    with open(stopwords_file_path, 'r', encoding='utf-8') as stopwords_file:
        stop_words = {line.strip() for line in stopwords_file}
    stop_words.discard('')

    return "".join(word for word in jieba.cut(text_cn) if word not in stop_words)

根据文件类型选择合适的函数读取文件

In [None]:
import os

# Resource file selected by param1 at the top of the notebook.
file_path = param1

# File name without directory or extension.
base_name = os.path.basename(file_path)
filename_ = os.path.splitext(base_name)[0]

# # 使用 'with' 语句确保文件正确关闭
# with open(file_path, 'r', encoding='utf-8') as file:
#     content = file.read()

# # 输出文件内容
# print(content)

import os

def read_file(file_path):
    """Extract the text of a file with the handler matching its extension.

    Args:
        file_path: Path of a .pdf, .docx or .pptx file (the extension is
            compared case-insensitively).

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If the extension is not one of the three above.
    """
    suffix = os.path.splitext(file_path)[1]
    kind = suffix.lower()

    if kind == '.pdf':
        return handle_pdf(file_path)
    if kind == '.docx':
        return handle_docx(file_path)
    if kind == '.pptx':
        return handle_pptx(file_path)
    raise ValueError(f"Unsupported file type: {suffix}")

# Extract the text of the resource file with the handler matching its extension.
content = read_file(file_path)

In [None]:
# Ask the model for a plain-text summary of the extracted text; the prompt
# asks for at most 200 characters.
file_explanation = get_completion_glm_flash("请总结文本内容 %s，以纯文本格式输出，字数限在200字以内" % content)

In [None]:
file_explanation

## 生成文件的元数据

启动mongodb：`mongod --dbpath xxxxx`

In [None]:
# Metadata stored with the file in GridFS. The node id is the "EC_R_" prefix
# followed by (number of existing EC_R nodes + 1), zero-padded to 4 digits.
metadata = {
    "node_id": f"EC_R_{(node_count + 1):04}",
    "name_zh": filename_,
    "explanation": file_explanation,
    'type': "resource",
    'timeZoneOffset': "+08:00" # UTC is the default time zone; an offset of +8 is Beijing time
}

# Read the whole file into memory.
with open(file_path, "rb") as file:
    file_data = file.read()

# Upload the bytes to GridFS under the file's base name, with the metadata attached.
file_id = bucket.upload_from_stream(
    filename=filename_,
    source=file_data,
    metadata=metadata
)
print("File uploaded with metadata.")

## 获得 neo4j 中知识点实体的所有中文名

In [None]:
from neo4j import GraphDatabase

# Re-open the Neo4j connection (the earlier driver was closed after counting).
# Values can be overridden through environment variables so that real
# credentials are not saved in the notebook; the defaults keep the original
# local-development values.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")  # Neo4j Bolt URI
username = os.environ.get("NEO4J_USERNAME", "neo4j")  # Neo4j user name
password = os.environ.get("NEO4J_PASSWORD", "xxx")  # Neo4j password

driver = GraphDatabase.driver(uri, auth=(username, password))

In [None]:
def fetch_knowledge_entities():
    """Return the ``name_zh`` of every node whose ``type`` is ``'knowledge'``.

    Uses the module-level ``driver``.

    Returns:
        List of ``name_zh`` values in the order the query yields them.
    """
    query = "MATCH (n) WHERE n.type = 'knowledge' RETURN n.name_zh AS name_zh"
    names = []
    with driver.session() as session:
        for record in session.run(query):
            names.append(record["name_zh"])
    return names

# Fetch the Chinese names of all knowledge entities and print them.
knowledge_names = fetch_knowledge_entities()
print(knowledge_names)

# The driver is intentionally left open here: later cells (the node-id lookup
# and the graph-creation step) still query Neo4j through it, and the
# graph-creation function closes it once it has finished writing.

## RRF 算法排序

bm25 算法

In [None]:
from rank_bm25 import BM25Okapi

corpus = knowledge_names

tokenized_corpus = [doc.split(" ") for doc in corpus]
# tokenized_corpus
bm25 = BM25Okapi(tokenized_corpus)
# <rank_bm25.BM25Okapi at 0x1047881d0>

In [None]:
query = content  # 资料文件中的文本内容
tokenized_query = query.split(" ")

doc_scores = bm25.get_scores(tokenized_query)
doc_scores
# array([0.        , 0.93729472, 0.        ])

In [None]:
sorted_entity_names_bm25 = bm25.get_top_n(tokenized_query, corpus, n=100)
# ['It is quite windy in London']
sorted_entity_names_bm25

启动 docker 中的 milvus

In [None]:
import logging
logging.basicConfig(level=logging.INFO)  # 只显示 INFO 级别及以上的日志

from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
from text2vec import SentenceModel
import json


def connect_to_milvus():
    """Open the "default" Milvus connection on localhost:19530.

    Prints a confirmation on success. On failure the error is printed and the
    original exception is re-raised.
    """
    endpoint = {"host": "localhost", "port": "19530"}
    try:
        connections.connect("default", **endpoint)
        print("Connected to Milvus.")
    except Exception as error:
        print(f"Failed to connect to Milvus: {error}")
        raise
        
def text2vec_embedding(text):
    """Encode ``text`` with the 'shibing624/text2vec-base-chinese' SentenceModel.

    Args:
        text: Input passed to ``SentenceModel.encode``.

    Returns:
        The embedding converted with ``.tolist()``.
    """
    encoder = SentenceModel('shibing624/text2vec-base-chinese')
    return encoder.encode(text).tolist()

In [None]:
connect_to_milvus()

In [None]:
# get collection
collection_name = "entity_vec" 
collection = Collection(name=collection_name)  

In [None]:
def extract_entity_names(results):
    """Flatten search results into entity names ordered by ascending distance.

    Args:
        results: Iterable of hit lists; each hit exposes
            ``entity.get('entity_name')`` and ``distance``.

    Returns:
        Entity names sorted by distance, smallest first. Hits with equal
        distance keep their original relative order.
    """
    scored_names = [
        (hit.entity.get('entity_name'), hit.distance)
        for hits in results
        for hit in hits
    ]
    ranked = sorted(scored_names, key=lambda pair: pair[1])
    return [name for name, _ in ranked]

def search_and_query(collection, search_vectors, search_field, search_params):
    """Run a vector search on ``collection`` and return the ranked entity names.

    Args:
        collection: Collection object; it is loaded before searching.
        search_vectors: Query vectors passed to ``collection.search``.
        search_field: Name of the vector field to search.
        search_params: Search parameters passed to ``collection.search``.

    Returns:
        The result of ``extract_entity_names`` on the search output; the
        search requests up to 100 hits with the ``entity_name`` field.
    """
    collection.load()
    hits_per_query = collection.search(
        search_vectors,
        search_field,
        search_params,
        limit=100,
        output_fields=["entity_name"],
    )
    return extract_entity_names(hits_per_query)

In [None]:
query = content  
query_vector = text2vec_embedding(query)  

In [None]:
sorted_entity_names_milvus = search_and_query(collection, [query_vector], "embeddings", {"metric_type": "L2", "params": {"nprobe": 10}})
sorted_entity_names_milvus

In [None]:
def rrf_rank_fusion(list1, list2, k=60):
    """Fuse two rankings with Reciprocal Rank Fusion (RRF).

    Every item receives ``1 / (k + rank)`` from each list it appears in, with
    ``rank`` starting at 1, and the contributions are summed. The smoothing
    constant ``k`` is applied to both lists; it used to be applied to
    ``list2`` only, which let ``list1`` decide the order on its own.

    Args:
        list1: First ranking, best item first.
        list2: Second ranking, best item first.
        k: Smoothing constant, usually 60.

    Returns:
        All items from both lists ordered by descending fused score. Items
        with equal scores keep their first-seen order.
    """
    score_map = {}
    for ranking in (list1, list2):
        for rank, doc in enumerate(ranking, start=1):
            score_map[doc] = score_map.get(doc, 0) + 1 / (k + rank)

    merged_list = sorted(score_map.items(), key=lambda item: item[1], reverse=True)
    return [doc for doc, _ in merged_list]

In [None]:
# Fuse the Milvus (vector search) ranking with the BM25 ranking.
list1 = sorted_entity_names_milvus
list2 = sorted_entity_names_bm25

merged_result = rrf_rank_fusion(list1, list2)  
print("合并后的排序结果:", merged_result) 

In [None]:
def get_file_metadata(filename_):
    """Look up a file in the module-level GridFS ``bucket`` and return its node data.

    Args:
        filename_: GridFS filename to search for.

    Returns:
        Dict with ``node_id``, ``name_zh``, ``explanation`` and ``type`` copied
        from the stored metadata plus ``file_id``. When several files share
        the name, the last one yielded by the cursor wins; when none match,
        the dict is empty.
    """
    copied_keys = ('node_id', 'name_zh', 'explanation', 'type')
    metadata_neo4j = {}
    for grid_out in bucket.find({'filename': filename_}):
        for key in copied_keys:
            metadata_neo4j[key] = grid_out.metadata[key]
        metadata_neo4j['file_id'] = grid_out._id
    return metadata_neo4j

In [None]:
metadata_neo4j = get_file_metadata(filename_)
metadata_neo4j

In [None]:
def get_node_ids(driver, keywords):
    """Return the ``id`` of every node whose ``name_zh`` is in ``keywords``.

    Args:
        driver: Open Neo4j driver.
        keywords: List of Chinese names to match.

    Returns:
        List of node ids in the order the query yields them.
    """
    cypher = "MATCH (n) WHERE n.name_zh IN $keywords RETURN n.id AS node_id"
    found_ids = []
    with driver.session() as session:
        for record in session.run(cypher, keywords=keywords):
            found_ids.append(record["node_id"])
    return found_ids

In [None]:
# Look up and print the node ids of the 20 top-ranked entity names.
node_ids = get_node_ids(driver, merged_result[0:20])  
print(node_ids)

In [None]:
import json  

def relate_res_know(node_resource_id, node_ids):
    """Build "relate" relationship records from one resource to many knowledge nodes.

    Args:
        node_resource_id: Id of the resource entity.
        node_ids: Ids of the related knowledge entities.

    Returns:
        A JSON-serialisable list with one dict per knowledge id, shaped like
        ``{"A_id": <resource id>, "B_id": <knowledge id>, "type": "relate"}``.
    """
    return [
        {"A_id": node_resource_id, "B_id": knowledge_id, "type": "relate"}
        for knowledge_id in node_ids
    ]

In [None]:
# Id of the resource node, taken from the metadata stored in GridFS.
node_resource_id = metadata_neo4j['node_id']
node_ids = node_ids  # no-op self-assignment; node_ids was computed in an earlier cell

# One "relate" record per related knowledge node.
data_relate = relate_res_know(node_resource_id, node_ids)
data_relate

In [None]:
# List holding both kinds of records: the resource node followed by its relationships.
res_data = [metadata_neo4j] + data_relate 
res_data

## 在 neo4j 中添加资料实体

In [None]:
def create_graph(data):
    """Write entity and relationship records to Neo4j, then close the driver.

    Uses the module-level ``driver``.

    Args:
        data: List of dicts. A dict whose ``type`` is ``include``, ``relate``
            or ``order`` is a relationship between the nodes with ids
            ``A_id`` and ``B_id``. Any other dict is an entity and needs
            ``node_id``, ``type``, ``name_zh`` and ``explanation``;
            ``name_en`` is optional and defaults to an empty string.

    An entity that lacks a required key is printed and skipped. A
    relationship is created only when both end nodes exist; otherwise a
    message and the two lookup results are printed.
    """
    relationship_types = ('include', 'relate', 'order')
    required_node_keys = ('node_id', 'name_zh', 'explanation')

    with driver.session() as session:
        for item in data:
            if item['type'] not in relationship_types:
                # Skip incomplete entities. Previously the item was only
                # printed and the CREATE still ran, using values left over
                # from the previous item (or failing on an undefined name).
                if any(key not in item for key in required_node_keys):
                    print(item)
                    continue
                session.run(
                    "CREATE (n:Item {id: $id, type: $type, name_zh: $name_zh, name_en: $name_en, explanation: $explanation})",
                    id=item['node_id'],
                    type=item['type'],
                    name_zh=item['name_zh'],
                    name_en=item.get('name_en', ""),
                    explanation=item['explanation'],
                )
                continue

            relationship_type = item['type']
            A_id = item['A_id']
            B_id = item['B_id']

            # Read each lookup exactly once and keep the record, instead of
            # calling single() again on a result that was already read.
            node_a = session.run("MATCH (n) WHERE n.id = $A_id RETURN n", A_id=A_id).single()
            node_b = session.run("MATCH (n) WHERE n.id = $B_id RETURN n", B_id=B_id).single()

            if node_a and node_b:
                # The relationship type cannot be a query parameter, so it is
                # interpolated; it is safe because this branch only runs for
                # the three types listed in relationship_types.
                cypher_query = (
                    f"MATCH (a), (b) WHERE a.id = $A_id AND b.id = $B_id "
                    f"CREATE (a)-[:{relationship_type}]->(b)"
                )
                session.run(cypher_query, A_id=A_id, B_id=B_id)
            else:
                print(A_id + B_id + relationship_type + "One or both nodes do not exist.")
                print(node_a)
                print(node_b)

    # Close the Neo4j driver.
    driver.close()

In [None]:
create_graph(res_data)