In [3]:
import os
import re
from collections import defaultdict
from getpass import getpass

import networkx as nx
from graspologic.partition import hierarchical_leiden

from llama_index.core import PropertyGraphIndex
from llama_index.core.llms import ChatMessage
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore

  from .autonotebook import tqdm as notebook_tqdm


### GraphRAGStore

In [4]:
class GraphRAGStore(Neo4jPropertyGraphStore):
    """Neo4j property-graph store with helpers for GraphRAG community detection.

    Adds two helpers on top of ``Neo4jPropertyGraphStore``:
    ``_create_nx_graph`` builds an undirected NetworkX view of the stored
    triplets, and ``_collect_community_info`` groups relationship details by
    community cluster.
    """

    # Class-level default kept for backward compatibility with code that reads
    # GraphRAGStore.community_summary; instances get their own dict in __init__.
    community_summary = {}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A mutable class attribute is shared by every instance; shadow it with
        # a per-instance dict so summaries never leak between stores.
        self.community_summary = {}

    def _create_nx_graph(self):
        """Convert the stored triplets into an undirected NetworkX graph.

        Nodes are keyed by entity name. Each edge stores the relation label
        under the ``relationship`` attribute. If several triplets connect the
        same pair, the last one's label wins, because ``nx.Graph.add_edge``
        updates existing edge data.

        Returns:
            nx.Graph: the entity graph.
        """
        nx_graph = nx.Graph()

        for source, relation, target in self.get_triplets():
            # Key edges by the same identifiers used for nodes. Mixing entity
            # names (nodes) with relation.source_id/target_id (edges) creates
            # duplicate, disconnected nodes whenever the two differ.
            # TODO: also store relation.properties['relationship_description']
            # once the extractor populates it.
            nx_graph.add_edge(
                source.name,
                target.name,
                relationship=relation.label,
            )

        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Group relationship details by community cluster.

        An entity may belong to several clusters.

        Args:
            nx_graph: Graph whose edges carry a ``relationship`` attribute, as
                built by ``_create_nx_graph``.
            clusters: Iterable of items exposing ``.node`` and ``.cluster``
                (presumably ``hierarchical_leiden`` output; cluster ids are
                assumed mutually comparable, e.g. ints).

        Returns:
            tuple[dict, dict]: ``entity_info`` maps each node to a sorted list
            of its cluster ids. ``community_info`` maps each cluster id to a
            list of ``"node -> neighbor -> relationship"`` strings.
        """
        # A set per node ignores repeated (node, cluster) pairs.
        entity_clusters = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node, cluster_id = item.node, item.cluster
            entity_clusters[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = f"{node} -> {neighbor} -> {edge_data['relationship']}"
                    community_info[cluster_id].append(detail)

        # Sorted lists are serializable and, unlike raw set order, deterministic.
        entity_info = {node: sorted(ids) for node, ids in entity_clusters.items()}
        return entity_info, dict(community_info)

In [5]:
from collections import defaultdict

# (node, cluster) pairs; some pairs repeat on purpose to show that a
# set-valued defaultdict collapses duplicates automatically.
data = [
    ("node_1", "cluster_A"),
    ("node_1", "cluster_B"),
    ("node_2", "cluster_A"),
    ("node_1", "cluster_A"),  # Duplicate cluster for node_1
    ("node_3", "cluster_C"),
    ("node_2", "cluster_B"),
    ("node_1", "cluster_B"),  # Duplicate cluster for node_1
]

# Missing keys start as an empty set, so .add() needs no existence check.
entity_info = defaultdict(set)
for node, cluster in data:
    entity_info[node].add(cluster)

# Plain dict with list values; list order follows set iteration order.
entity_info_dict = dict(
    (key, list(cluster_set)) for key, cluster_set in entity_info.items()
)

print("Entity Info (defaultdict):")
print(entity_info)
print("\nEntity Info (as dictionary with lists):")
print(entity_info_dict)


Entity Info (defaultdict):
defaultdict(<class 'set'>, {'node_1': {'cluster_B', 'cluster_A'}, 'node_2': {'cluster_B', 'cluster_A'}, 'node_3': {'cluster_C'}})

Entity Info (as dictionary with lists):
{'node_1': ['cluster_B', 'cluster_A'], 'node_2': ['cluster_B', 'cluster_A'], 'node_3': ['cluster_C']}


In [6]:
# Rich-display the defaultdict: the repeated (node, cluster) pairs were collapsed by the set values.
entity_info

defaultdict(set,
            {'node_1': {'cluster_A', 'cluster_B'},
             'node_2': {'cluster_A', 'cluster_B'},
             'node_3': {'cluster_C'}})

In [7]:
# defaultdict is a dict subclass, so it compares equal to a plain dict with the same items.
entity_info == dict(entity_info)

True

### GraphRAGQueryEngine

In [8]:
class GraphRAGQueryEngine(CustomQueryEngine):
    """Query engine intended to answer questions from GraphRAG community data.

    NOTE(review): ``CustomQueryEngine`` subclasses must implement
    ``custom_query``. It is not defined in this cell, so the class cannot be
    instantiated yet. TODO: add ``custom_query``.
    """

    graph_store: GraphRAGStore  # store holding community summaries / entity info
    index: PropertyGraphIndex  # index used to build the retriever in get_entities

    similarity_top_k: int = 20  # default retrieval size for get_entities

    def get_entities(self, query_str, similarity_top_k=None):
        """Return entity names found in the nodes retrieved for a query.

        Each retrieved node's text is scanned line by line for triplets of the
        form ``subject -> relation -> object``. Subjects and objects are
        collected as entity names.

        Args:
            query_str: Query text passed to the retriever.
            similarity_top_k: Number of nodes to retrieve. When ``None``, falls
                back to ``self.similarity_top_k``.

        Returns:
            list[str]: Unique entity names in first-seen order, which follows
            the retriever's result order.
        """
        if similarity_top_k is None:
            similarity_top_k = self.similarity_top_k

        nodes_retrieved = self.index.as_retriever(
            similarity_top_k=similarity_top_k
        ).retrieve(query_str)

        # Entity names are word runs joined by spaces/tabs only. A plain \s
        # would also match newlines and glue a preceding text line onto the
        # subject. The relation may contain underscores/digits (e.g. WORKS_AT).
        name = r"\w+(?:[ \t]+\w+)*"
        triplet_pattern = re.compile(
            rf"^({name})[ \t]*->[ \t]*([\w \t]+?)[ \t]*->[ \t]*({name})[ \t]*$",
            re.MULTILINE,
        )

        # A dict deduplicates while keeping insertion order; a set's order is arbitrary.
        entities = {}
        for node in nodes_retrieved:
            for subject, _relation, obj in triplet_pattern.findall(node.text):
                entities[subject] = None
                entities[obj] = None

        return list(entities)

### Pipeline

In [10]:
# Read Neo4j connection settings from the environment instead of hardcoding
# credentials in the notebook. If NEO4J_PASSWORD is unset, prompt for it.
# NOTE: the password previously committed here should be considered leaked
# and rotated.
graph_store = GraphRAGStore(
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: "),
    url=os.environ.get("NEO4J_URL", "bolt://localhost:7687"),
)

In [11]:
# Display the store instance; with no custom repr this shows only its type and memory address.
graph_store

<__main__.GraphRAGStore at 0x2976746fbb0>