In [None]:
import networkx as nx
from community import community_louvain
from openai import OpenAI
from pyvis.network import Network
import json
from collections import  defaultdict
from dotenv import load_dotenv
import os
# 初始化知识图谱组件
from OmniStore.chromadb_store import StoreTool
from sentence_transformers import SentenceTransformer
from KnowledgeGraphManager.KGManager import KgManager
import leidenalg
from igraph import Graph as IGraph

load_dotenv(dotenv_path="./.env")

# Device handed to the embedding model's .to(); read from the DEVICE variable.
device = os.getenv("DEVICE")

# IS_USE_LOCAL == "True" selects the EMBEDDINGS_PATH variable; any other value
# (or an unset variable) selects EMBEDDINGS.
embedding_source_var = (
    "EMBEDDINGS_PATH" if os.getenv("IS_USE_LOCAL") == "True" else "EMBEDDINGS"
)
embeddings = SentenceTransformer(os.getenv(embedding_source_var)).to(device)

# Storage tool backed by the path in CHROMADB_PATH, sharing the embedding model.
chromadb_store = StoreTool(
    storage_path=os.getenv("CHROMADB_PATH"),
    embedding_function=embeddings,
)

# Text LLM client.
client = OpenAI(api_key=os.getenv("API_KEY"), base_url=os.getenv("BASE_URL"))

# Multimodal model client.
# If the environment variable is not configured, pass the Bailian API key
# directly instead, e.g. api_key="sk-xxx".
vl_client = OpenAI(
    api_key=os.getenv("VL_API_KEY"),
    base_url=os.getenv("VL_BASE_URL"),
)

from LLM.Openai_Agent import OpenaiAgent

# Two agent objects (RAG and knowledge graph); both wrap the same `client`.
rag_agent = OpenaiAgent(client)
kg_agent = OpenaiAgent(client)

# File lists for each splitter type, read from comma-separated environment
# variables (SIMPLE, SEMANTIC, CHARACTER).
def _env_file_list(var_name):
    """Return the non-blank, stripped entries of a comma-separated env variable.

    ``"".split(",")`` yields ``[""]``, so without dropping blank entries an
    unset variable would still produce a non-empty list.
    """
    return [item.strip() for item in os.getenv(var_name, "").split(",") if item.strip()]


simple_files = _env_file_list("SIMPLE")
semantic_files = _env_file_list("SEMANTIC")
character_files = _env_file_list("CHARACTER")

# Splitter used by the knowledge-graph manager; chosen below.
kg_splitter = None

# Pick the splitter for the first configured file list. Project splitter
# classes are imported lazily so only the selected implementation is loaded.
if simple_files:
    from TextSlicer.SimpleTextSplitter import SimpleTextSplitter
    kg_splitter = SimpleTextSplitter(2045, 1024)
elif semantic_files:
    from TextSlicer.SemanticTextSplitter import SemanticTextSplitter
    kg_splitter = SemanticTextSplitter(2045, 1024)
elif character_files:
    from TextSlicer.CharacterTextSplitter import CharacterTextSplitter
    kg_splitter = CharacterTextSplitter(separator="</end>", keep_separator=False, max_tokens=2045, min_tokens=1024)
else:
    # Nothing configured: keep the simple splitter, which is what this script
    # effectively always selected before blank entries were filtered out.
    from TextSlicer.SimpleTextSplitter import SimpleTextSplitter
    kg_splitter = SimpleTextSplitter(2045, 1024)

# Knowledge-graph manager wired to the agent, splitter, embedding model and
# store created earlier in this script.
kg_manager = KgManager(
    agent=kg_agent,
    splitter=kg_splitter,
    embedding_model=embeddings,
    store=chromadb_store,
)

# Load the store named "知识融合".
# NOTE(review): presumably this populates kg_manager.current_G - confirm in KgManager.
kg_manager.load_store("知识融合")


# Both names are read from kg_manager.current_G.
G = kg_manager.current_G
G2 = kg_manager.current_G
print(G)

# ================== Community detection ==================
# Both algorithms run on the same undirected copy of the graph. Building the
# igraph edge list from G2.edges() directly would, if G2 is directed, add a
# reciprocal pair u->v / v->u as two parallel edges, so Leiden would partition
# a different graph than Louvain does.
undirected_G = G2.to_undirected()

# Louvain algorithm
community_map = community_louvain.best_partition(undirected_G)
print(community_map, 2)

# Leiden algorithm
# Convert the networkx graph to igraph: igraph vertices are consecutive
# integers, so each node is mapped to its position in `nodes`.
nodes = list(undirected_G.nodes())
node_id_map = {node: i for i, node in enumerate(nodes)}
edges = [(node_id_map[u], node_id_map[v]) for u, v in undirected_G.edges()]

# Build the igraph graph object
igraph_G = IGraph(directed=False)
igraph_G.add_vertices(len(nodes))
igraph_G.add_edges(edges)

# Run Leiden community detection
partition = leidenalg.find_partition(
    igraph_G,
    leidenalg.ModularityVertexPartition,
    n_iterations=-1  # iterate until convergence
)
leiden_communities = partition.membership
# Map the per-vertex membership list back to the original node objects.
leiden_community_map = {node: leiden_communities[node_id_map[node]] for node in nodes}

# ================== 可视化部分 ==================
def community_color(comm_id):
    """Return a ``#rrggbb`` colour string for an integer community id.

    The id is multiplied by 0x0F0F0F and wrapped into 24 bits so the result
    always has exactly six hex digits. Without the wrap, ids of 18 and above
    produce seven digits and negative ids produce a minus sign, neither of
    which is a valid colour. Ids 0..17 map to the same colours as before.
    """
    return "#%06x" % ((comm_id * 0x0F0F0F) % 0x1000000)


def visualize_communities(graph, community_dict, filename):
    """Render the graph with nodes coloured by community and write it to HTML.

    :param graph: graph object exposing ``nodes()`` and ``edges()``
    :param community_dict: community assignment dict {node: community id};
        must contain every node of ``graph``
    :param filename: output file name passed to ``Network.show``
    :raises KeyError: if a node of ``graph`` is missing from ``community_dict``
    """
    net = Network(height="600px", width="100%", notebook=True)

    # Add nodes, grouped and coloured by their community id
    for node in graph.nodes():
        comm_id = community_dict[node]
        net.add_node(
            node,
            label=str(node),
            group=comm_id,
            color=community_color(comm_id)
        )

    # Add edges
    for source, target in graph.edges():
        net.add_edge(source, target)

    # Enable the physics engine and write the file
    net.toggle_physics(True)
    net.show(filename)


class CommunityRetriever:
    """Look up entities that share a community with the given query entities.

    Candidates are the other members of the query entity's community, ranked
    by the mean edge weight along a shortest path from the query entity.
    """

    def __init__(self, graph, community_map):
        """
        :param graph: NetworkX graph object (edges may carry a ``weight`` attribute)
        :param community_map: community assignment dict {node: community id}
        """
        self.graph = graph
        self.community_map = community_map
        self._build_community_index()

    def _build_community_index(self):
        """Build the reverse index {community id: [nodes]} from ``community_map``."""
        self.community_index = defaultdict(list)
        for node, comm_id in self.community_map.items():
            self.community_index[comm_id].append(node)

    def _path_score(self, path):
        """Return the mean edge weight along ``path`` (a list of nodes).

        Edges without a ``weight`` attribute count as 1. The sum is divided by
        the number of edges (``len(path) - 1``), not the number of nodes, so
        the value is a true mean. A path with no edges scores 0.
        """
        hop_count = len(path) - 1
        if hop_count < 1:
            return 0
        total_weight = sum(
            self.graph[u][v].get('weight', 1)
            for u, v in zip(path, path[1:])
        )
        return total_weight / hop_count

    def find_related_entities(self, query_entities, top_n=20):
        """
        Weighted retrieval of related entities within a community.

        :param query_entities: list of query entities; entities that are not
            in the graph are skipped
        :param top_n: number of related entities returned per query entity
        :return: dict {entity: {'community': community id,
            'related': [(node, score), ...]}} with ``related`` sorted by
            descending score
        """
        results = {}

        for entity in query_entities:
            if entity not in self.graph:
                continue

            # Community of the entity; -1 when it has no assignment
            comm_id = self.community_map.get(entity, -1)
            community_nodes = self.community_index.get(comm_id, [])

            scores = []
            for node in community_nodes:
                if node == entity:
                    continue

                try:
                    # Shortest path by hop count (direct or multi-hop)
                    path = nx.shortest_path(self.graph, source=entity, target=node)
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # Unreachable, or the community map names a node that is
                    # not in the graph: treat both as unrelated.
                    score = 0
                else:
                    score = self._path_score(path)

                scores.append((node, score))

            # Sort by score, highest first (stable for equal scores)
            scores.sort(key=lambda item: item[1], reverse=True)
            results[entity] = {
                'community': comm_id,
                'related': scores[:top_n]
            }

        return results


# Write one HTML visualisation per partition: Louvain first, then Leiden.
for mapping, html_name in (
    (community_map, "louvain_community.html"),
    (leiden_community_map, "leiden_community.html"),
):
    visualize_communities(G, mapping, html_name)

# Retriever based on the Louvain partition
louvain_retriever = CommunityRetriever(
    graph=G.to_undirected(),
    community_map=community_map,
)

# Retriever based on the Leiden partition
leiden_retriever = CommunityRetriever(
    graph=G.to_undirected(),
    community_map=leiden_community_map,
)

# Example usage
if __name__ == "__main__":
    def _show(title, result):
        """Print a heading followed by the result as indented JSON."""
        print(title)
        print(json.dumps(result, indent=5, ensure_ascii=False))

    # Entities assumed to come from a user question
    query_entities = [
        "牛顿力学",
        "星海科技",
        "经典物理",
    ]

    # Retrieval using the Louvain partition
    louvain_result = louvain_retriever.find_related_entities(query_entities)
    _show("Louvain社区检索结果：", louvain_result)

    # Retrieval using the Leiden partition
    leiden_result = leiden_retriever.find_related_entities(query_entities)
    _show("\nLeiden社区检索结果：", leiden_result)