In [1]:
import os
import sys
import asyncio
from pathlib import Path

import numpy as np
import nest_asyncio
from dotenv import load_dotenv
from IPython.display import Markdown, display


def _find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above `start` that contains a `lightrag/` package.

    Falls back to `start.parent`, i.e. assumes the notebook lives one level
    below the repository root (in `tests/`).
    """
    for candidate in (start, *start.parents):
        if (candidate / "lightrag").is_dir():
            return candidate
    return start.parent


# Jupyter has no __file__, so start from the kernel's working directory.
notebook_dir = Path(os.getcwd())
# Searching upward (instead of blindly taking .parent) keeps this cell idempotent:
# after the os.chdir() below, a re-run resolves to the same repo root again.
parent_dir = _find_repo_root(notebook_dir)
if str(parent_dir) not in sys.path:
    sys.path.append(str(parent_dir))

# The lightrag imports must come after the sys.path update so the local checkout is importable.
from lightrag import LightRAG, QueryParam
from lightrag.utils import logger, EmbeddingFunc
from lightrag.llm.openai import (
    gpt_4o_mini_complete,
    gpt_4o_complete,
    openai_embed,
    openai_complete_if_cache,
)
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.llm.siliconcloud import siliconcloud_embedding
from lightrag.chunking import set_hierarchical_chunking_config, get_hierarchical_chunking_config

# Lets LightRAG's async operations (e.g. rag.insert / rag.query) run inside
# Jupyter's already-running event loop.
nest_asyncio.apply()

# Work from the repo root so relative paths (WORKING_DIR, tests/technical_manual.md)
# resolve the same way on any machine, without a hard-coded absolute path.
os.chdir(parent_dir)

WORKING_DIR = "data/TechnicalDemo"
os.makedirs(WORKING_DIR, exist_ok=True)

# Load environment variables (e.g. SILICONFLOW_API_KEY) from a .env file.
load_dotenv()

# Chat model served by SiliconFlow's OpenAI-compatible endpoint. Other model ids
# tried with this notebook: "Qwen/Qwen2.5-32B-Instruct",
# "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B".
LLM_MODEL = "Qwen/Qwen2.5-14B-Instruct"
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1/"


async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete `prompt` with LLM_MODEL through SiliconFlow.

    Args:
        prompt: the user prompt.
        system_prompt: optional system message.
        history_messages: prior chat messages; None means "no history".
        keyword_extraction: accepted to match the signature LightRAG calls with,
            but not forwarded (presumably openai_complete_if_cache does not
            accept it — confirm before forwarding).
        **kwargs: forwarded unchanged to openai_complete_if_cache.

    Returns:
        The completion text.
    """
    return await openai_complete_if_cache(
        LLM_MODEL,
        prompt,
        system_prompt=system_prompt,
        # A fresh list per call; a shared mutable default would leak state
        # between calls if anything downstream mutated it.
        history_messages=history_messages if history_messages is not None else [],
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        base_url=SILICONFLOW_BASE_URL,
        **kwargs,
    )


# Single embedding model used for every request. Mixing models (the previous code
# switched to a different one for inputs larger than one batch) yields vectors of
# different widths than the dimension probed at start-up, breaking the vector DB.
EMBEDDING_MODEL = "netease-youdao/bce-embedding-base_v1"
# Maximum number of texts sent in one embedding request.
EMBEDDING_MAX_BATCH_SIZE = 32


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed `texts` with EMBEDDING_MODEL via SiliconFlow, batching large inputs.

    Inputs of at most EMBEDDING_MAX_BATCH_SIZE texts are sent as one request.
    Larger inputs are split into consecutive batches, and the per-batch results
    are stacked row-wise in input order.

    Args:
        texts: strings to embed.

    Returns:
        The embeddings, one row per input text.
    """

    async def _embed_batch(batch: list[str]) -> np.ndarray:
        # One API call for a single batch; every batch uses the same model.
        return await siliconcloud_embedding(
            batch,
            model=EMBEDDING_MODEL,
            api_key=os.getenv("SILICONFLOW_API_KEY"),
            max_token_size=512,
        )

    if len(texts) <= EMBEDDING_MAX_BATCH_SIZE:
        return await _embed_batch(texts)

    batch_results = []
    for start in range(0, len(texts), EMBEDDING_MAX_BATCH_SIZE):
        batch_results.append(await _embed_batch(texts[start:start + EMBEDDING_MAX_BATCH_SIZE]))
    return np.vstack(batch_results)

# Load the manual (relative to the current working directory). read_text closes the
# file, and the explicit UTF-8 avoids platform-dependent decoding of the Chinese text.
technical_doc = Path("tests", "technical_manual.md").read_text(encoding="utf-8")


Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple
/Users/llp/opensource/LightRAG


[31mERROR: Could not find a version that satisfies the requirement lmdeploy (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for lmdeploy[0m[31m
[0m

In [2]:
from lightrag.base import ChunkingMode
from lightrag.kg.shared_storage import initialize_pipeline_status


async def get_embedding_dim():
    """Return the embedding width by embedding one probe string and measuring its vector."""
    probe_vectors = await embedding_func(["test"])
    first_vector = probe_vectors[0]
    return len(first_vector)

async def initialize_rag():
    """Build a LightRAG instance for the "technical_manual" domain and prepare its storages.

    Steps, in order:
    1. Probe the embedding width.
    2. Construct LightRAG.
    3. Apply the hierarchical chunking config.
    4. Register the domain's extraction settings.
    5. Switch to hierarchical chunking.
    6. Initialize pipeline status, then storages.

    Returns:
        The initialized LightRAG instance.
    """
    dim = await get_embedding_dim()
    print(f"Detected embedding dimension: {dim}")

    domain_name = "technical_manual"
    # NOTE(review): max_token_size=8192 here, while embedding_func passes
    # max_token_size=512 to the embedding API — confirm which limit is intended.
    wrapped_embed = EmbeddingFunc(
        embedding_dim=dim,
        max_token_size=8192,
        func=embedding_func,
    )
    # The domain can already be chosen at construction time.
    instance = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=wrapped_embed,
        domain=domain_name,
    )

    # Extraction settings for the domain. Entity types are:
    # component / fault cause / fault symptom / repair step.
    domain_config = dict(
        language="中文",
        entity_types=["部件", "故障原因", "故障现象", "维修步骤"],
        tuple_delimiter="<|>",
        record_delimiter="##",
        completion_delimiter="<|COMPLETE|>",
    )
    # NOTE(review): the intended mapping was heading_levels=3 -> headings down to
    # "####" and parent_level=2 -> "###" headings as parent documents; confirm the
    # level numbering used by lightrag.chunking.
    set_hierarchical_chunking_config(
        heading_levels=3,
        parent_level=2,
        preprocess_attachments=False,  # do not preprocess attachment headings
    )
    instance.register_domain(domain_name, domain_config)

    # NOTE(review): the member name must match lightrag.base.ChunkingMode exactly;
    # "HIREARCHIACL" looks like a misspelling of HIERARCHICAL made at the enum definition.
    instance.chunking_mode = ChunkingMode.HIREARCHIACL

    await initialize_pipeline_status()
    await instance.initialize_storages()
    return instance

In [3]:
# Build the configured LightRAG instance, then index the technical manual.
rag = await initialize_rag()
# NOTE(review): re-running this cell calls ainsert again on the same text; confirm
# LightRAG skips already-indexed documents rather than duplicating them.
await rag.ainsert(technical_doc)

INFO: Process 13161 Shared-Data created for Single Process
INFO: Process 13161 initializing storage statuses


Detected embedding dimension: 768


INFO:nano-vectordb:Init {'embedding_dim': 768, 'metric': 'cosine', 'storage_file': 'data/TechnicalDemo/vdb_entities.json'} 0 data
INFO:nano-vectordb:Init {'embedding_dim': 768, 'metric': 'cosine', 'storage_file': 'data/TechnicalDemo/vdb_relationships.json'} 0 data
INFO:nano-vectordb:Init {'embedding_dim': 768, 'metric': 'cosine', 'storage_file': 'data/TechnicalDemo/vdb_chunks.json'} 0 data
INFO: Process 13161 Pipeline namespace initialized
INFO: Process 13161 initialized updated flags for namespace: [full_docs]
INFO: Process 13161 ready to initialize storage namespace: [full_docs]
INFO: Process 13161 initialized updated flags for namespace: [text_chunks]
INFO: Process 13161 ready to initialize storage namespace: [text_chunks]
INFO: Process 13161 initialized updated flags for namespace: [entities]
INFO: Process 13161 initialized updated flags for namespace: [relationships]
INFO: Process 13161 initialized updated flags for namespace: [chunks]
INFO: Process 13161 initialized updated flags

In [None]:
# Fuzzy entity lookup.
def find_closest_entity(query, entities):
    """Return the entities whose name contains `query`, closest match first.

    Matching is case-insensitive and ignores the double quotes that graph node
    names may be wrapped in. Matches are ordered by the length of the unquoted
    name, so an exact name comes before longer names that merely contain the
    query. Ties keep their input order.

    Args:
        query: substring to look for.
        entities: iterable of entity / node names.

    Returns:
        The matching entity names as given (quotes not stripped).
    """
    needle = query.casefold()
    scored = []
    for entity in entities:
        clean_entity = entity.replace('"', '')
        if needle in clean_entity.casefold():
            scored.append((len(clean_entity), entity))
    # Sort on the length only; sorted() is stable, so equal lengths keep input order.
    return [entity for _, entity in sorted(scored, key=lambda pair: pair[0])]

# 查看已提取的所有实体
all_nodes = list(rag.chunk_entity_relation_graph._graph.nodes())
print("知识图谱中的所有实体:")
for node in all_nodes:
    print(node)
# 使用模糊匹配
search_term = "空调噪音"
possible_matches = find_closest_entity(search_term, all_nodes)

if possible_matches:
    print(f"找到可能匹配的实体: {possible_matches}")
    tech_entity = possible_matches[0].replace('"', '')  # 使用第一个匹配项
    print(f"\n=== 查询技术实体: {tech_entity} ===")
    tech_entity_info = await rag.get_entity_info(tech_entity)
    # ... 其余代码保持不变
else:
    print(f"未找到与 '{search_term}' 匹配的实体")

In [None]:
# 查询实体和关系
# 查询实体和关系
medical_entity = "糖尿病"
tech_entity = "CRH380A"

# 查询医疗实体
print(f"\n=== 查询医疗实体: {medical_entity} ===")
medical_entity_info = await rag.get_entity_info(medical_entity)

if medical_entity_info["graph_data"]:
    print(f"类型: {medical_entity_info['graph_data'].get('entity_type', 'N/A')}")
    print(f"描述: {medical_entity_info['graph_data'].get('description', 'N/A')}")
    
    # 使用正确的方法获取相关节点和边
    # 格式化实体名称
    formatted_entity = f'"{medical_entity.upper()}"'
    
    # 获取所有出边（从该实体指向其他实体的关系）
    try:
        # 获取与该实体相连的所有边
        neighbors = list(rag.chunk_entity_relation_graph._graph.neighbors(formatted_entity))
        
        if neighbors:
            print("\n相关关系:")
            for neighbor in neighbors:
                # 获取边的属性
                edge_data = rag.chunk_entity_relation_graph._graph.get_edge_data(formatted_entity, neighbor)
                if edge_data:
                    print(f"- 目标实体: {neighbor}")
                    print(f"  描述: {edge_data.get('description', 'N/A')}")
                    print(f"  关键词: {edge_data.get('keywords', 'N/A')}")
                    print(f"  权重: {edge_data.get('weight', 'N/A')}")
        else:
            print("没有找到相关关系")
            
    except Exception as e:
        print(f"获取关系时出错: {e}")
else:
    print(f"未找到实体: {medical_entity}")

# 查询技术实体    
print(f"\n=== 查询技术实体: {tech_entity} ===")
tech_entity_info = await rag.get_entity_info(tech_entity)

if tech_entity_info["graph_data"]:
    print(f"类型: {tech_entity_info['graph_data'].get('entity_type', 'N/A')}")
    print(f"描述: {tech_entity_info['graph_data'].get('description', 'N/A')}")
    
    # 使用正确的方法获取相关节点和边
    # 格式化实体名称
    formatted_entity = f'"{tech_entity.upper()}"'
    
    # 获取所有出边（从该实体指向其他实体的关系）
    try:
        # 获取与该实体相连的所有边
        neighbors = list(rag.chunk_entity_relation_graph._graph.neighbors(formatted_entity))
        
        if neighbors:
            print("\n相关关系:")
            for neighbor in neighbors:
                # 获取边的属性
                edge_data = rag.chunk_entity_relation_graph._graph.get_edge_data(formatted_entity, neighbor)
                if edge_data:
                    print(f"- 目标实体: {neighbor}")
                    print(f"  描述: {edge_data.get('description', 'N/A')}")
                    print(f"  关键词: {edge_data.get('keywords', 'N/A')}")
                    print(f"  权重: {edge_data.get('weight', 'N/A')}")
        else:
            print("没有找到相关关系")
            
    except Exception as e:
        print(f"获取关系时出错: {e}")
else:
    print(f"未找到实体: {tech_entity}")

In [None]:
# QueryParam.mode selects the retrieval strategy:
#   - naive:        plain vector search over text chunks, no knowledge graph
#   - hierarchical: chunk retrieval that takes the document's heading hierarchy into account
#   - local:        knowledge-graph retrieval focused on entities related to the query
#   - global:       knowledge-graph retrieval over relationships / broader context
#   - hybrid:       local and global knowledge-graph retrieval combined
#   - mix:          knowledge-graph retrieval combined with vector search over chunks


tech_manual_query_param = QueryParam(
    mode="naive",  # baseline: vector search over chunks only
    top_k=5,  # retrieve fewer but more precise results
    disable_response_cache=True,
)

# Domain system prompt for repair questions.
# NOTE(review): currently not passed to rag.query(); add system_prompt=tech_system_prompt
# if this LightRAG version's query() accepts that argument.
tech_system_prompt = """你是一个专业的技术维修顾问。
回答问题时请注意：
1. 严格遵循技术参数（电压、压力、电流等数值）
2. 清晰区分不同设备型号的特定要求
3. 按正确顺序列出维修步骤
4. 引用相应的手册章节
请仅基于提供的文档内容回答，不要臆测信息。"""

response = rag.query(
    "车门状态指示灯 (例如：开门指示灯、关门指示灯、故障指示灯) 显示错误或不亮",
    param=tech_manual_query_param,
)
display(Markdown(response))


In [None]:
# Hierarchical mode: chunk retrieval that takes the manual's heading hierarchy into account.
tech_manual_query_param = QueryParam(
    mode="hierarchical",
    top_k=5,  # retrieve fewer but more precise results
    disable_response_cache=True,
)

response = rag.query(
    "车门状态指示灯 (例如：开门指示灯、关门指示灯、故障指示灯) 显示错误或不亮",
    param=tech_manual_query_param,
)
display(Markdown(response))



In [None]:
# Local mode: knowledge-graph retrieval focused on entities related to the query.
tech_manual_query_param = QueryParam(
    mode="local",
    top_k=5,  # retrieve fewer but more precise results
    disable_response_cache=True,
)

response = rag.query(
    "出现噪声过大时的处理方法",
    param=tech_manual_query_param,
)
display(Markdown(response))


In [None]:
# Global mode: knowledge-graph retrieval over relationships / broader context.
tech_manual_query_param = QueryParam(
    mode="global",
    top_k=5,  # retrieve fewer but more precise results
    disable_response_cache=True,
)

response = rag.query(
    "出现噪声过大时的处理方法",
    param=tech_manual_query_param,
)
display(Markdown(response))


In [None]:
# Mix mode: knowledge-graph retrieval combined with vector search over chunks.
# top_k=5 retrieves fewer but more precise results; the response cache is bypassed.
tech_manual_query_param = QueryParam(mode="mix", top_k=5, disable_response_cache=True)

response = rag.query("出现噪声过大时的处理方法", param=tech_manual_query_param)
display(Markdown(response))
