## Step 1: Install Dependencies

In [None]:
!pip install llama-index graspologic numpy==1.24.4 scipy==1.12.0 

## Step 2: Load and Preprocess Data

Load sample news data, which will be chunked into smaller parts for easier processing. For demonstration, we limit it to 50 samples. Each row (title and text) is converted into a Document object.

In [1]:
import pandas as pd
from llama_index.core import Document

In [2]:
# Download the sample news dataset and keep only its first 50 rows
news_url = "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"
news = pd.read_csv(news_url).head(50)

In [3]:
# Wrap every article in a LlamaIndex Document whose text is "<title>: <text>"
documents = [
    Document(text=f"{title}: {text}")
    for title, text in zip(news["title"], news["text"])
]

## Step 3: Split Text into Nodes

Use SentenceSplitter to break down documents into manageable chunks.

In [4]:
from llama_index.core.node_parser import SentenceSplitter

# Chunk every document: chunk size 1024, with an overlap of 20 between
# neighbouring chunks.
splitter = SentenceSplitter(chunk_overlap=20, chunk_size=1024)
nodes = splitter.get_nodes_from_documents(documents)

`chunk_overlap=20` makes consecutive chunks overlap slightly, so that information falling on a chunk boundary is not lost.

## Step 4: Configure the LLM, Prompt, and GraphRAG Extractor

Set up the LLM (e.g., GPT-4). This LLM will later analyze the chunks to extract entities and relationships.

In [5]:
from llama_index.llms.openai import OpenAI
import os

# The OpenAI client needs an API key. Uncomment the next line to set it inline
# (presumably it is otherwise read from the OPENAI_API_KEY environment
# variable — confirm). Never commit a real key with the notebook.
# os.environ["OPENAI_API_KEY"] = "your_openai_api_key"
# Client used for entity/relationship extraction and for answering queries.
llm = OpenAI(model="gpt-4")

The GraphRAGExtractor uses the above LLM, a prompt template to guide the extraction process, and a parsing function to process the LLM’s output into structured data. Text chunks (called nodes) are fed into the extractor. For each chunk, the extractor sends the text to the LLM along with the prompt, which instructs the LLM to identify entities, their types, and their relationships. The response is parsed by a function (parse_fn), which extracts the entities and relationships. These are then converted into EntityNode objects (for entities) and Relation objects (for relationships), with descriptions stored as metadata. The extracted entities and relationships are saved into the text chunk’s metadata, ready for use in building knowledge graphs or performing queries.

In [7]:
import asyncio
import nest_asyncio
 
nest_asyncio.apply()
 
from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display
 
from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
    default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
    EntityNode,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
    Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field
import re

In [8]:
# Patterns for the labelled fields the extraction prompt asks the LLM to emit.
# Intermediate fields are matched lazily up to the next label. The LAST field
# of each record is anchored to the end of its line ((?m) makes `$` match at
# every line end); without that anchor a trailing lazy group would capture only
# a single character of the description.
entity_pattern = (
    r"(?m)entity_name:\s*(.+?)\s*"
    r"entity_type:\s*(.+?)\s*"
    r"entity_description:\s*(.+?)\s*$"
)
relationship_pattern = (
    r"(?m)source_entity:\s*(.+?)\s*"
    r"target_entity:\s*(.+?)\s*"
    r"relation:\s*(.+?)\s*"
    r"relationship_description:\s*(.+?)\s*$"
)


def parse_fn(response_str: str) -> Any:
    """Parse an LLM extraction response into entities and relationships.

    Args:
        response_str: Raw LLM output containing ``entity_name:`` /
            ``entity_type:`` / ``entity_description:`` records and
            ``source_entity:`` / ``target_entity:`` / ``relation:`` /
            ``relationship_description:`` records.

    Returns:
        A tuple ``(entities, relationships)`` where ``entities`` is a list of
        ``(name, type, description)`` tuples and ``relationships`` is a list of
        ``(source, target, relation, description)`` tuples. Both lists are
        empty when nothing matches.
    """
    entities = re.findall(entity_pattern, response_str)
    relationships = re.findall(relationship_pattern, response_str)
    return entities, relationships

In [9]:
class GraphRAGExtractor(TransformComponent):
    """Extract entities and relationships, with descriptions, from text nodes.

    For each node the LLM is prompted with the node text, the response is
    handed to ``parse_fn``, and the parsed entities/relationships are stored in
    the node's metadata under ``KG_NODES_KEY`` and ``KG_RELATIONS_KEY``.

    Args:
        llm (LLM):
            The language model to use. Falls back to ``Settings.llm``.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extraction. It is formatted with ``text`` and
            ``max_knowledge_triplets``.
        parse_fn (callable):
            Parses the LLM response. It must return ``(entities,
            relationships)`` where every entity is a
            ``(name, type, description)`` tuple and every relationship is a
            ``(source, target, relation, description)`` tuple.
            NOTE(review): the default ``default_parse_triplets_fn`` presumably
            returns plain triplets and may not satisfy this contract — confirm
            before relying on the default.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk (passed to the
            prompt as ``max_knowledge_triplets``).
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params; a plain-string prompt is wrapped in a PromptTemplate."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        # NOTE(review): differs from the Python class name; left unchanged in
        # case anything persisted depends on this identifier.
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract entities and relationships from nodes (sync wrapper).

        Drives ``acall`` with ``asyncio.run``. Inside an already-running event
        loop (e.g. a notebook) this relies on ``nest_asyncio.apply()`` having
        been called beforehand.
        """
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract entities and relationships from a single node.

        The node is modified in place (its KG metadata entries are extended)
        and returned. A ``ValueError`` raised while prompting or parsing is
        treated as "nothing extracted".
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        # Snapshot taken after the KG entries were popped, so it only holds the
        # node's ordinary metadata. Every graph element gets its own dict built
        # from it; nothing leaks from one element to the next.
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            entity_node = EntityNode(
                name=entity,
                label=entity_type,
                # Not used in the current implementation, but kept for future work.
                properties={**base_metadata, "entity_description": description},
            )
            existing_nodes.append(entity_node)

        # Ids already represented in the node list, so that a bare placeholder
        # endpoint is never added on top of a typed, described entity.
        known_ids = {getattr(kg_node, "id", None) for kg_node in existing_nodes}

        # Field order follows the prompt / parse_fn contract:
        # (source, target, relation, description).
        for subj, obj, rel, description in entities_relationship:
            subj_node = EntityNode(name=subj, properties=dict(base_metadata))
            obj_node = EntityNode(name=obj, properties=dict(base_metadata))
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties={
                    **base_metadata,
                    "relationship_description": description,
                },
            )

            for endpoint in (subj_node, obj_node):
                if endpoint.id not in known_ids:
                    known_ids.add(endpoint.id)
                    existing_nodes.append(endpoint)
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract entities and relationships from nodes asynchronously.

        One extraction job is created per node and the jobs are run with at
        most ``num_workers`` workers.
        """
        jobs = [self._aextract(node) for node in nodes]

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )

In [10]:
# Extraction prompt. Placeholders: {max_knowledge_triplets} and {text}.
# The output format is spelled out as "label: value" lines because the response
# is parsed by matching those labels; do not add other curly braces here, since
# the template is filled in by placeholder name.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity on three separate lines, exactly as:
entity_name: <entity_name>
entity_type: <entity_type>
entity_description: <entity_description>

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

Format each relationship on four separate lines, exactly as:
source_entity: <source_entity>
target_entity: <target_entity>
relation: <relation>
relationship_description: <relationship_description>

3. When finished, output.

-Real Data-
######################
text: {text}
######################
output:"""

In [11]:
# Wire the extractor to the GPT-4 client, the custom prompt and the regex
# parser; the prompt will ask for at most 2 triplets per chunk.
kg_extractor = GraphRAGExtractor(
    llm=llm,
    parse_fn=parse_fn,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_paths_per_chunk=2,
)

## Step 5: Build the Graph Index

The PropertyGraphIndex extracts entities and relationships from text using kg_extractor and stores them as nodes and edges in the GraphRAGStore.

In [13]:
import re
from llama_index.core.graph_stores import SimplePropertyGraphStore
import networkx as nx
from graspologic.partition import hierarchical_leiden
from llama_index.core.llms import ChatMessage

Matplotlib is building the font cache; this may take a moment.
  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class GraphRAGStore(SimplePropertyGraphStore):
    """Property graph store that detects communities and summarizes them.

    Communities are found with ``hierarchical_leiden`` on a NetworkX view of
    the stored graph; each community's relationships are then summarized by an
    LLM and cached in ``community_summary``.
    """

    # Class-level defaults. ``community_summary`` here is only the empty
    # fallback: summaries are stored by rebinding the attribute on the
    # instance, so this dict is never mutated or shared between stores.
    community_summary = {}
    max_cluster_size = 5

    def generate_community_summary(self, text):
        """Generate a summary of the given relationship listing using an LLM.

        Args:
            text: Relationship lines in the form
                ``entity1 -> entity2 -> relation -> relationship_description``.

        Returns:
            The LLM reply as a string, with a leading ``assistant:`` prefix
            removed.
        """
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        # A default-configured OpenAI client is created for every call.
        response = OpenAI().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Build communities from the graph and (re)generate their summaries."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Convert the internal graph representation to an undirected NetworkX graph.

        Edges carry the relation label (``relationship``) and its description
        (``description``; empty when the relation has none).
        """
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            # Key nodes by id, the same identifier the relations use for their
            # endpoints, so nodes and edge endpoints refer to the same vertices.
            nx_graph.add_node(node.id)
        for relation in self.graph.relations.values():
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                # Tolerate relations that were stored without a description.
                description=relation.properties.get(
                    "relationship_description", ""
                ),
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect, per community, the relationships between its member nodes.

        Returns:
            A dict mapping cluster id to a list of
            ``"node -> neighbor -> relationship -> description"`` strings. An
            edge is listed only when both endpoints map to that cluster.
        """
        # When a node occurs in several cluster entries the last one wins.
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        community_info[cluster_id].append(detail)
        return community_info

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community.

        Communities without any collected relationship are skipped: there is
        nothing to summarize and an LLM call on an empty listing would only
        produce a meaningless summary.
        """
        summaries = {}
        for community_id, details in community_info.items():
            if not details:
                continue
            details_text = (
                "\n".join(details) + "."
            )  # Ensure it ends with a period
            summaries[community_id] = self.generate_community_summary(
                details_text
            )
        # Rebind on the instance instead of mutating the class-level dict, so
        # stores do not share summaries and a rebuild drops stale entries.
        self.community_summary = summaries

    def get_community_summaries(self):
        """Return the community summaries, building them if none exist yet."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary

In [15]:
from llama_index.core import PropertyGraphIndex

In [16]:
# Build the property graph index over the chunks, using kg_extractor for
# extraction and a fresh GraphRAGStore as the backing graph store.
index = PropertyGraphIndex(
    nodes=nodes,
    kg_extractors=[kg_extractor],
    property_graph_store=GraphRAGStore(),
    show_progress=True,
)

Extracting paths from text:  98%|█████████▊| 49/50 [04:39<00:04,  4.67s/it]Retrying llama_index.llms.openai.base.OpenAI._achat in 0.4856244932849 seconds as it raised RateLimitError: Error code: 429 - {'error': {'message': 'Rate limit reached for gpt-4 in organization org-<redacted> on tokens per min (TPM): Limit 10000, Used 8880, Requested 1336. Please try again in 1.296s. Visit https://platform.openai.com/account/rate-limits to learn more.', 'type': 'tokens', 'param': None, 'code': 'rate_limit_exceeded'}}.
Extracting paths from text: 100%|██████████| 50/50 [06:06<00:00,  7.32s/it]
Generating embeddings: 100%|██████████| 1/1 [00:08<00:00,  8.52s/it]
Generating embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]


## Step 6: Detect Communities and Summarize

Use graspologic’s Hierarchical Leiden algorithm to detect communities and generate summaries. Communities are groups of nodes (entities) that are densely connected internally but sparsely connected to other groups. This algorithm maximizes a metric called modularity, which measures the quality of dividing a graph into communities.

In [17]:
# Cluster the stored graph with hierarchical Leiden and have the LLM summarize
# the detected communities (this issues chat requests, so it can take a while).
index.property_graph_store.build_communities()



## Step 7: Query the Graph

Initialize the GraphRAGQueryEngine to query the processed data. When a query is submitted, the engine retrieves relevant community summaries from the GraphRAGStore. For each summary, it uses the LLM to generate a specific answer contextualized to the query via the generate_answer_from_summary method. These partial answers are then synthesized into a coherent final response using the aggregate_answers method, where the LLM combines multiple perspectives into a concise output.

In [18]:
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.llms import LLM

In [19]:
class GraphRAGQueryEngine(CustomQueryEngine):
    """Answer queries from community summaries.

    Every community summary held by ``graph_store`` is turned into a partial
    answer by the LLM; the partial answers are then merged by a final LLM call.
    """

    graph_store: GraphRAGStore
    llm: LLM

    def custom_query(self, query_str: str) -> str:
        """Answer ``query_str`` using all community summaries."""
        summaries = self.graph_store.get_community_summaries()

        partial_answers = []
        for summary in summaries.values():
            partial_answers.append(
                self.generate_answer_from_summary(summary, query_str)
            )

        return self.aggregate_answers(partial_answers)

    def generate_answer_from_summary(self, community_summary, query):
        """Ask the LLM to answer ``query`` from a single community summary."""
        system_prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        user_prompt = "I need an answer based on the above information."

        reply = self.llm.chat(
            [
                ChatMessage(role="system", content=system_prompt),
                ChatMessage(role="user", content=user_prompt),
            ]
        )
        # Drop the leading "assistant:" role prefix from the stringified reply.
        return re.sub(r"^assistant:\s*", "", str(reply)).strip()

    def aggregate_answers(self, community_answers):
        """Merge the per-community answers into one concise final response."""
        system_prompt = "Combine the following intermediate answers into a final, concise response."
        user_prompt = f"Intermediate answers: {community_answers}"

        reply = self.llm.chat(
            [
                ChatMessage(role="system", content=system_prompt),
                ChatMessage(role="user", content=user_prompt),
            ]
        )
        # Drop the leading "assistant:" role prefix from the stringified reply.
        return re.sub(r"^assistant:\s*", "", str(reply)).strip()

In [20]:
# Query the graph through its community summaries and render the answer
query_engine = GraphRAGQueryEngine(
    llm=llm,
    graph_store=index.property_graph_store,
)
response = query_engine.query("What are news related to financial sector?")
answer_markdown = Markdown(str(response.response))
display(answer_markdown)

The provided summary does not contain any specific news or information related to the financial sector.