# GraphRAG with LlamaIndex Implementation

The concept of GraphRAG was first introduced in the research paper titled **“From Local to Global: A Graph RAG Approach to Query-Focused Summarization,”** by Darren Edge, Ha Trinh, Newman Cheng, Joshua Bradley, Alex Chao, Apurva Mody, Steven Truitt, Jonathan Larson, published in April 2024.  

This paper addresses the limitations of traditional Retrieval-Augmented Generation (RAG) systems when handling global queries over extensive text corpora. The authors propose a novel approach that integrates large language models (LLMs) with graph-based text indexing to construct an entity knowledge graph from source documents. By generating community summaries for groups of related entities, the system can produce comprehensive and diverse responses to broad, sensemaking questions. Evaluations demonstrate that this Graph RAG method significantly outperforms baseline RAG models in both comprehensiveness and diversity of answers, particularly for large datasets.

The goal of this project is to reproduce "GraphRAG Implementation with LlamaIndex", which is available at this [link.](https://docs.llamaindex.ai/en/stable/examples/cookbooks/GraphRAG_v1/)

## References

1. Research Paper **“From Local to Global: A Graph RAG Approach to Query-Focused Summarization”** by Darren Edge, Ha Trinh, Newman Cheng, Joshua Bradley, Alex Chao, Apurva Mody, Steven Truitt, Jonathan Larson, published in April 2024 and available at this [link.](https://arxiv.org/abs/2404.16130)

2. "GraphRAG Implementation with LlamaIndex" project, which is available at LlamaIndex.ai at this [link.](https://docs.llamaindex.ai/en/stable/examples/cookbooks/GraphRAG_v1/)


### Setup

In [None]:
# Mounting to Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd "YOUR-PATH-HERE"

In [None]:
%%capture
!pip install llama-index graspologic numpy==1.24.4 scipy==1.12.0

### Load Data

In [None]:
import pandas as pd
from llama_index.core import Document

news = pd.read_csv(
    "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"
)[:50]

news.head()

| Unnamed: 0 | title | date | text |
| --- | --- | --- | --- |
| 0 | Chevron: Best Of Breed | 2031-04-06T01:36:32.000000000+00:00 | JHVEPhoto Like many companies in the O&G secto... |
| 1 | FirstEnergy (NYSE:FE) Posts Earnings Results | 2030-04-29T06:55:28.000000000+00:00 | FirstEnergy (NYSE:FE – Get Rating) posted its ... |
| 2 | Dáil almost suspended after Sinn Féin TD put p... | 2023-06-15T14:32:11.000000000+00:00 | The Dáil was almost suspended on Thursday afte... |
| 3 | Epic’s latest tool can animate hyperrealistic ... | 2023-06-15T14:00:00.000000000+00:00 | Today, Epic is releasing a new tool designed t... |
| 4 | EU to Ban Huawei, ZTE from Internal Commission... | 2023-06-15T13:50:00.000000000+00:00 | The European Commission is planning to ban equ... |


In [None]:
# Convert each news article (title and text) into a LlamaIndex Document so it can be chunked and indexed for RAG.

documents = [
    Document(text=f"{row['title']}: {row['text']}")
    for i, row in news.iterrows()
]

### Setup API Key and LLM

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "YOUR-OPENAI-API-KEY-HERE"

from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-3.5-turbo")

### GraphRAG

According to the LlamaIndex GraphRAG project, the extraction process consists of the following steps:

 "For each input node (chunk of text):

* It sends the text to the LLM along with the extraction prompt.
* The LLM's response is parsed to extract entities, relationships, descriptions for entities and relations.
* Entities are converted into EntityNode objects. Entity description is stored in metadata
* Relationships are converted into Relation objects. Relationship description is stored in metadata.
* These are added to the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY."
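
The steps quoted above can be sketched without an LLM or LlamaIndex. The snippet below is a minimal illustration under stated assumptions, not the cookbook's implementation: `fake_llm`, `parse_response`, and `extract_into_metadata` are hypothetical names, the `"kg_nodes"` and `"kg_relations"` keys stand in for `KG_NODES_KEY` and `KG_RELATIONS_KEY`, and the pipe-delimited response format is assumed only for this example.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins for KG_NODES_KEY / KG_RELATIONS_KEY.
NODES_KEY = "kg_nodes"
RELATIONS_KEY = "kg_relations"


def fake_llm(prompt: str) -> str:
    """Stand-in for an LLM call; always returns one canned extraction."""
    return (
        "entity|ACME|Company|A fictional manufacturer\n"
        "entity|BOB|Person|A fictional engineer\n"
        "relationship|BOB|ACME|works at|Bob is employed by Acme"
    )


def parse_response(response: str) -> Tuple[List[tuple], List[tuple]]:
    """Split the response into entity and relationship tuples."""
    entities, relationships = [], []
    for line in response.splitlines():
        kind, *fields = line.split("|")
        if kind == "entity" and len(fields) == 3:
            entities.append(tuple(fields))
        elif kind == "relationship" and len(fields) == 4:
            relationships.append(tuple(fields))
    return entities, relationships


def extract_into_metadata(
    chunk: Dict, llm: Callable[[str], str], prompt_template: str
) -> Dict:
    """Run the extraction steps for one chunk and store results in its metadata."""
    # 1. Send the text to the LLM along with the extraction prompt.
    response = llm(prompt_template.format(text=chunk["text"]))
    # 2. Parse entities and relationships out of the response.
    entities, relationships = parse_response(response)
    # 3. Entities become node records; the description goes into properties.
    nodes = [
        {"name": name, "label": etype, "properties": {"entity_description": desc}}
        for name, etype, desc in entities
    ]
    # 4. Relationships become relation records with their description.
    relations = [
        {
            "label": rel,
            "source_id": src,
            "target_id": tgt,
            "properties": {"relationship_description": desc},
        }
        for src, tgt, rel, desc in relationships
    ]
    # 5. Attach both lists to the chunk's metadata.
    chunk["metadata"][NODES_KEY] = nodes
    chunk["metadata"][RELATIONS_KEY] = relations
    return chunk


chunk = {"text": "Bob works at Acme.", "metadata": {}}
result = extract_into_metadata(chunk, fake_llm, "Extract triplets from: {text}")
print(result["metadata"][RELATIONS_KEY][0]["label"])  # works at
```

In the real extractor, the records are `EntityNode` and `Relation` objects rather than plain dictionaries, but the flow from prompt to metadata is the same.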

In [None]:
import asyncio
import nest_asyncio

nest_asyncio.apply()

from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display

from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
    default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
    EntityNode,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
    Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field

In [None]:
class GraphRAGExtractor(TransformComponent):
    """Extract triples from text.

    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a node."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        metadata = node.metadata.copy()
        for entity, entity_type, description in entities:
            metadata[
                "entity_description"
            ] = description  # Not used in the current implementation. But will be useful in future work.
            entity_node = EntityNode(
                name=entity, label=entity_type, properties=metadata
            )
            existing_nodes.append(entity_node)

        metadata = node.metadata.copy()
        for triple in entities_relationship:
            # Tuple order follows the prompt format: source, target, relation, description.
            subj, obj, rel, description = triple
            subj_node = EntityNode(name=subj, properties=metadata)
            obj_node = EntityNode(name=obj, properties=metadata)
            metadata["relationship_description"] = description
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties=metadata,
            )

            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
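
`num_workers` caps how many extraction calls are in flight at once. How `run_jobs` does this internally is not shown in this notebook. One common way to get the same effect with only the standard library is to wrap each coroutine in a semaphore, and the sketch below assumes that approach. `bounded_gather` and `fake_extract` are hypothetical names.

```python
import asyncio
from typing import Awaitable, List, TypeVar

T = TypeVar("T")


async def bounded_gather(jobs: List[Awaitable[T]], workers: int) -> List[T]:
    """Await all jobs, keeping at most `workers` of them running at a time."""
    semaphore = asyncio.Semaphore(workers)

    async def run_one(job: Awaitable[T]) -> T:
        async with semaphore:
            return await job

    # gather preserves the input order of the results.
    return await asyncio.gather(*(run_one(job) for job in jobs))


active = 0  # number of fake_extract calls currently running
peak = 0  # highest value `active` ever reached


async def fake_extract(chunk_id: int) -> int:
    """Stand-in for one LLM extraction call; tracks concurrent executions."""
    global active, peak
    active += 1
    peak = max(peak, active)
    await asyncio.sleep(0.01)
    active -= 1
    return chunk_id


results = asyncio.run(bounded_gather([fake_extract(i) for i in range(10)], workers=4))
print(results)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(peak <= 4)  # True
```

Inside a notebook an event loop is already running, so `asyncio.run` only works after `nest_asyncio.apply()`; otherwise use `await bounded_gather(...)` directly in the cell.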

### GraphRAGStore
This class runs a community detection algorithm (hierarchical Leiden) to group closely connected nodes into communities, then uses an LLM to generate a summary for each community.

In [None]:
import re
from llama_index.core.graph_stores import SimplePropertyGraphStore
import networkx as nx
from graspologic.partition import hierarchical_leiden

from llama_index.core.llms import ChatMessage


class GraphRAGStore(SimplePropertyGraphStore):
    community_summary = {}
    max_cluster_size = 5

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM."""
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Builds communities from the graph and summarizes them."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            nx_graph.add_node(str(node))
        for relation in self.graph.relations.values():
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties["relationship_description"],
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect detailed information for each node based on their community."""
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        community_info[cluster_id].append(detail)
        return community_info

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = (
                "\n".join(details) + "."
            )  # Ensure it ends with a period
            self.community_summary[
                community_id
            ] = self.generate_community_summary(details_text)

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
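
The grouping performed by `_collect_community_info` can be illustrated on a toy graph with plain dictionaries. The sketch below is a simplification under stated assumptions: the node-to-community mapping is written by hand instead of coming from `hierarchical_leiden`, the data is invented, and `edges`, `community_of`, and `collect_community_info` are hypothetical names. It keeps only edges whose endpoints share a community, and lists each edge once.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (source, target, relation, description): toy data, not from the dataset.
edges: List[Tuple[str, str, str, str]] = [
    ("A", "B", "partners with", "A and B signed a deal"),
    ("B", "C", "competes with", "B and C sell rival products"),
    ("C", "D", "acquired", "C bought D"),
]

# Hand-written stand-in for the output of a community detection step.
community_of: Dict[str, int] = {"A": 0, "B": 0, "C": 1, "D": 1}


def collect_community_info(
    edges: List[Tuple[str, str, str, str]], community_of: Dict[str, int]
) -> Dict[int, List[str]]:
    """Group intra-community edges into text lines, keyed by community id."""
    info: Dict[int, List[str]] = defaultdict(list)
    for source, target, relation, description in edges:
        # Skip edges that cross a community boundary.
        if community_of[source] != community_of[target]:
            continue
        info[community_of[source]].append(
            f"{source} -> {target} -> {relation} -> {description}"
        )
    return dict(info)


community_info = collect_community_info(edges, community_of)
for community_id, details in community_info.items():
    print(community_id, details)
```

Each list of lines would then be joined with newlines and passed to the summarization prompt, one LLM call per community.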

### GraphRAGQueryEngine

This class is "a custom query engine designed to process queries using the GraphRAG approach. It leverages the community summaries generated by the GraphRAGStore to answer user queries."

In [None]:
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.llms import LLM


class GraphRAGQueryEngine(CustomQueryEngine):
    graph_store: GraphRAGStore
    llm: LLM

    def custom_query(self, query_str: str) -> str:
        """Process all community summaries to generate answers to a specific query."""
        community_summaries = self.graph_store.get_community_summaries()
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for _, community_summary in community_summaries.items()
        ]

        final_answer = self.aggregate_answers(community_answers)
        return final_answer

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return cleaned_response

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        # intermediate_text = " ".join(community_answers)
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        cleaned_final_response = re.sub(
            r"^assistant:\s*", "", str(final_response)
        ).strip()
        return cleaned_final_response
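
The query flow is a map-reduce over community summaries: one answer per summary, then a single aggregation call. The sketch below shows only that control flow, with a stub in place of the LLM. `stub_llm`, `answer_query`, and the sample summaries are hypothetical, and the prompt strings are copied from the class above.

```python
from typing import Callable, Dict, List

calls: List[str] = []  # records the system prompt of every stubbed LLM call


def stub_llm(system: str, user: str) -> str:
    """Stand-in for a chat completion; returns a predictable tagged string."""
    calls.append(system)
    if system.startswith("Combine"):
        return f"combined {user.count('answer#')} answers"
    return f"answer#{len(calls)}"


def answer_query(
    summaries: Dict[int, str], query: str, llm: Callable[[str, str], str]
) -> str:
    """Answer a query per community summary, then merge the partial answers."""
    # Map step: ask the same question against every community summary.
    partial_answers: List[str] = [
        llm(
            f"Given the community summary: {summary}, "
            f"how would you answer the following query? Query: {query}",
            "I need an answer based on the above information.",
        )
        for summary in summaries.values()
    ]
    # Reduce step: merge the partial answers into one response.
    return llm(
        "Combine the following intermediate answers into a final, concise response.",
        f"Intermediate answers: {partial_answers}",
    )


summaries = {0: "A partners with B.", 1: "C acquired D."}
final = answer_query(summaries, "Who acquired D?", stub_llm)
print(final)  # combined 2 answers
print(len(calls))  # 3
```

The number of LLM calls per query is the number of communities plus one, so query cost grows linearly with the number of community summaries.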

### GraphRAG Pipeline

In [None]:
# Create chunks from the text

from llama_index.core.node_parser import SentenceSplitter

splitter = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=20,
)
nodes = splitter.get_nodes_from_documents(documents)
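
To build intuition for `chunk_size` and `chunk_overlap`, the sketch below implements a plain sliding window. It is a deliberate simplification and not how `SentenceSplitter` works internally: it counts whitespace-separated words instead of tokens and ignores sentence boundaries. `chunk_words` is a hypothetical helper.

```python
from typing import List


def chunk_words(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split text into word windows of `chunk_size` sharing `chunk_overlap` words."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    words = text.split()
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start : start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window already reached the end of the text
    return chunks


sample = " ".join(f"w{i}" for i in range(10))
chunks = chunk_words(sample, chunk_size=4, chunk_overlap=1)
print(chunks)  # ['w0 w1 w2 w3', 'w3 w4 w5 w6', 'w6 w7 w8 w9']
```

Consecutive chunks share their boundary words (`w3`, `w6`), which is the purpose of the overlap: text near a boundary still appears with some surrounding context in at least one chunk.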

In [None]:
len(nodes)

50

In [None]:
# Extraction prompt used to build the graph index

KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"$$$$"<entity_name>"$$$$"<entity_type>"$$$$"<entity_description>")

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

Format each relationship as ("relationship"$$$$"<source_entity>"$$$$"<target_entity>"$$$$"<relation>"$$$$"<relationship_description>")

3. When finished, output.

-Real Data-
######################
text: {text}
######################
output:"""

In [None]:
entity_pattern = r'\("entity"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\)'
relationship_pattern = r'\("relationship"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\)'


def parse_fn(response_str: str) -> Any:
    entities = re.findall(entity_pattern, response_str)
    relationships = re.findall(relationship_pattern, response_str)
    return entities, relationships


kg_extractor = GraphRAGExtractor(
    llm=llm,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_paths_per_chunk=2,
    parse_fn=parse_fn,
)
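
To see what `parse_fn` returns, the two patterns can be run against a hand-written response. The sample string below is hypothetical (it is not real LLM output), but it follows the format requested in the prompt. The patterns are copied from the cell above.

```python
import re

entity_pattern = r'\("entity"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\)'
relationship_pattern = r'\("relationship"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\$\$\$\$"(.+?)"\)'

# Hypothetical LLM output following the prompt's format.
sample_response = '''
("entity"$$$$"ACME"$$$$"Company"$$$$"A fictional manufacturer")
("entity"$$$$"BOB"$$$$"Person"$$$$"A fictional engineer")
("relationship"$$$$"BOB"$$$$"ACME"$$$$"Works at"$$$$"Bob is employed by Acme")
'''

entities = re.findall(entity_pattern, sample_response)
relationships = re.findall(relationship_pattern, sample_response)

print(entities)
# [('ACME', 'Company', 'A fictional manufacturer'), ('BOB', 'Person', 'A fictional engineer')]
print(relationships)
# [('BOB', 'ACME', 'Works at', 'Bob is employed by Acme')]

# Capture groups come back in prompt order: source, target, relation, description.
source, target, relation, description = relationships[0]
print(source, "->", relation, "->", target)  # BOB -> Works at -> ACME
```

Each relationship tuple follows the order of the capture groups in `relationship_pattern`, so any code that unpacks it needs to use the same order: source, target, relation, description.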

In [None]:
from llama_index.core import PropertyGraphIndex

index = PropertyGraphIndex(
    nodes=nodes,
    property_graph_store=GraphRAGStore(),
    kg_extractors=[kg_extractor],
    show_progress=True,
)

Extracting paths from text: 100%|██████████| 50/50 [05:02<00:00,  6.04s/it]
Generating embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
Generating embeddings: 100%|██████████| 6/6 [00:00<00:00,  6.15it/s]


In [None]:
list(index.property_graph_store.graph.nodes.values())[-1]

EntityNode(label='entity', embedding=None, properties={'relationship_description': 'Liverpool is scheduled to play against Wolves in the upcoming Premier League season.', 'triplet_source_id': '26101669-9e82-41d5-b787-651c4cf1e2b8'}, name='Competes against')

In [None]:
list(index.property_graph_store.graph.relations.values())[0]

Relation(label='NYSE:CVX', source_id='Chevron', target_id='Ownership', properties={'relationship_description': 'Chevron owns the stock NYSE:CVX. The performance of NYSE:CVX directly affects the financial status of Chevron.', 'triplet_source_id': 'd36c271b-202e-46e2-bd19-7ab525716fb5'})

In [None]:
list(index.property_graph_store.graph.relations.values())[0].properties[
    "relationship_description"
]

'Chevron owns the stock NYSE:CVX. The performance of NYSE:CVX directly affects the financial status of Chevron.'

### Build Communities

In [None]:
index.property_graph_store.build_communities()



### Create Query Engine

In [None]:
query_engine = GraphRAGQueryEngine(
    graph_store=index.property_graph_store, llm=llm
)

In [None]:
response = query_engine.query(
    "What are the main news discussed in the document?"
)
display(Markdown(f"{response.response}"))

The documents discuss a variety of news topics. These include FirstEnergy being a publicly traded company, a conflict between Sinn Féin TD John Brady and Minister for Housing Darragh O’Brien, the dispute between the Retained Firefighters and the Minister for Housing, the use of the MetaHuman Animator tool in Unreal Engine, Radivoje Bukvić starring in the short film "Blue Dot", the European Commission's decision to ban the use of TikTok Inc. by its staff, the deterioration of the relationship between the US and China, Arsenal's bid for the football player Rice, Jude Bellingham's transfer from Borussia Dortmund to Real Madrid, the termination of Aidan Murray's employment by Ryanair, KeyBank's ownership of the American Fork Branch, the teasing of a new smartphone, the Vivo X90s, the rivalry between XPeng and Tesla, the contract negotiations between player Maliek Collins and the Houston Texans, the acquisition of The Hollies' recording catalog by BMG, the distribution pact between ADA Worldwide and Rostrum Records, the marketing and distribution deal between Believe and Global Records Germany, the business relationship between Supplier.io and Hyatt Hotels, the recognition of Supplier.io and Hyatt Hotels as recipients of the 2023 Top Supply Chain Projects award, Gordon McQueen's football career, the supplier-customer relationship between GE Vernova and Amplus Solar, the development of the game "Star Ocean: The Second Story R" by Square Enix, the remastering of the game "Star Ocean: First Departure R" by Square Enix, the financial status of major financial institutions, Paytm's role as the provider of the Fastag service, Nirmal Bang's Buy rating to Tata Chemicals Ltd., JetBlue's ownership of an aircraft named A Defining MoMint, Coinbase Global Inc.'s repurchase of convertible senior notes, Vincent Kompany managing Burnley, Anil Goteti founding a fintech venture called Scapia, Manchester United and Chelsea considering André Onana as a potential transfer target, Thomas 
Christl's employment by Morgan Stanley, the launch of the Redmi 12C in India, the market rivalry between Xiaomi and Apple, the upcoming launch of the Hyundai Exter, Tarun Garg's role as the COO of Hyundai Motor India Limited, Generation Investment Management's holding of Henry Schein, Inc., Deutsche Bank's analyses of Allegiant Travel and SkyWest, the impacts of the COVID-19 pandemic on Delta Air Lines and Southwest Airlines, Manchester City, Real Madrid, and Inter Milan competing in the UEFA Champions League, Stellantis' decision to close the Belvidere Assembly Plant, the collaboration between J.B. Pritzker's administration and the Department of Commerce and Economic Opportunity, Tesla Inc. and General Motors Co.'s discussions and collaboration, the development of the MetaHuman Animator tool by Epic, the launch of a new digital campaign called #HameshaKaHiKyun by Sony SAB, the dismissal of Aidan Murray from his position at Ryanair, Bank of America's investment in Chingona Ventures, Coinbase Global, Inc.'s plan to repurchase its convertible senior notes, the competitive dynamics between UnitedHealth Group Inc. and Humana Inc., Uber's plan to exit Israel, Vincent Kompany's role in Manchester City's history, MDA Ltd.'s CFO, Vito Culmone, representing the company at the Jefferies Virtual Space Summit, and football player DeAndre Hopkins in talks with the Titans for a potential signing.

In [None]:
response = query_engine.query("What are news related to financial sector?")
display(Markdown(f"{response.response}"))

The provided information and summaries largely do not contain any news related to the financial sector. However, there are a few exceptions. KeyBank has expanded in the Western U.S. with a new branch and donated $10,000 to the Five.12 Foundation. BMG has acquired the recording catalog of The Hollies. Major financial institutions have reported their financial status to S&P Global Inc. Nirmal Bang has given a Buy rating to Tata Chemicals Ltd. Coinbase Global Inc. repurchased $64.5 million worth of 0.50% convertible senior notes. Morgan Stanley has hired Thomas Christl to co-lead its European consumer and retail client coverage. Deutsche Bank conducted analyses on Allegiant Travel and SkyWest. The COVID-19 pandemic significantly impacted Delta Air Lines and Southwest Airlines. Ihor Dusaniwsky of S3 Partners analyzed Tesla's stock trends. Bank of America has invested in Chingona Ventures. Coinbase Global, Inc. is planning to repurchase its 0.50% Convertible Senior Notes due 2026. UnitedHealth Group Inc. holds a significant position in the financial market.