#### **GraphRAG: Graph-Enhanced Retrieval-Augmented Generation**

- GraphRAG is an advanced question-answering system that combines the power of graph-based knowledge representation with retrieval-augmented generation. 
- It processes input documents to create a rich knowledge graph, which is then used to enhance the retrieval and generation of answers to user queries. 
- The system leverages natural language processing, machine learning, and graph theory to provide more accurate and contextually relevant responses.

#### **Motivation**

Traditional retrieval-augmented generation systems often struggle with maintaining context over long documents and making connections between related pieces of information. GraphRAG addresses these limitations by:

- Representing knowledge as an interconnected graph, allowing for better preservation of relationships between concepts.
- Enabling more intelligent traversal of information during the query process.
- Providing a visual representation of how information is connected and accessed during the answering process.

#### **Key Components**

- **DocumentProcessor**: Handles the initial processing of input documents, creating text chunks and embeddings.

- **KnowledgeGraph**: Constructs a graph representation of the processed documents, where nodes represent text chunks and edges represent relationships between them.

- **QueryEngine**: Manages the process of answering user queries by leveraging the knowledge graph and vector store.

- **Visualizer**: Creates a visual representation of the graph and the traversal path taken to answer a query.

#### **Method Details**

4) #### **Visualization**
    - The knowledge graph is visualized with nodes representing text chunks and edges representing relationships.
    - Edge colors indicate the strength of relationships (weights).
    - The traversal path taken to answer a query is highlighted with curved, dashed arrows.
    - Start and end nodes of the traversal are distinctly colored for easy identification.
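
The styling rules above can be sketched without any plotting library. This is a minimal illustration, not the notebook's `Visualizer` (which is commented out below): the function name `traversal_style` and the color names are hypothetical, and the actual drawing with networkx and matplotlib is left out. Note that consecutive nodes in a traversal path are not necessarily adjacent in the graph, which is one reason to draw the path as separate arrows rather than recoloring existing edges.

```python
def traversal_style(edge_weights, traversal_path):
    """Compute drawing attributes for a graph and a traversal path.

    edge_weights: {(node_a, node_b): weight}
    traversal_path: list of node ids in the order they were visited
    """
    nodes = {n for edge in edge_weights for n in edge} | set(traversal_path)

    # Default color for every node, then highlight the start and end of the path.
    # For a single-node path the start and end coincide and the end color wins.
    node_colors = {n: "lightblue" for n in nodes}
    if traversal_path:
        node_colors[traversal_path[0]] = "lightgreen"
        node_colors[traversal_path[-1]] = "lightcoral"

    # Rescale weights to [0, 1] so they can be fed to a colormap.
    edge_intensity = {}
    if edge_weights:
        lo, hi = min(edge_weights.values()), max(edge_weights.values())
        span = (hi - lo) or 1.0  # avoid dividing by zero when all weights are equal
        edge_intensity = {e: (w - lo) / span for e, w in edge_weights.items()}

    # Consecutive visits, to be drawn as curved, dashed arrows.
    hops = list(zip(traversal_path, traversal_path[1:]))
    return node_colors, edge_intensity, hops


node_colors, edge_intensity, hops = traversal_style(
    {(0, 1): 0.9, (1, 2): 0.7, (0, 2): 0.8}, [0, 2, 1]
)
print(hops)  # [(0, 2), (2, 1)]
```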

#### **Conclusion**

- GraphRAG represents a significant advancement in retrieval-augmented generation systems. By incorporating a graph-based knowledge representation and intelligent traversal mechanisms, it offers improved context awareness, more accurate retrieval, and enhanced explainability. 

- The system's ability to visualize its decision-making process provides valuable insights into its operation, making it a powerful tool for both end-users and developers.

- As natural language processing and graph-based AI continue to evolve, systems like GraphRAG pave the way for more sophisticated and capable question-answering technologies.

In [47]:
import networkx as nx

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
# from langchain.retrievers import ContextualCompressionRetriever
# from langchain.retrievers.document_compressors import LLMChainExtractor

# from langchain.callbacks import get_openai_callback
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import sys
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from typing import List, Tuple, Dict
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import nltk
import spacy
import heapq

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np

from spacy.cli import download
from spacy.lang.en import English

# Original path append replaced for Colab compatibility
# from helper_functions import *
# from evaluation.evalute_rag import *

# Load environment variables from a .env file
load_dotenv()

nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)

True

---

#### **LLM used**

In [18]:
from langchain_ollama import ChatOllama 

llm = ChatOllama(
    model='llama3.2',
    verbose=True,
    temperature=0.2
)

llm.invoke("Hey How are you?").content

"I'm just a language model, so I don't have emotions or feelings like humans do. However, I'm functioning properly and ready to help with any questions or tasks you may have! How can I assist you today?"

---

#### **Embedding Model** 

In [22]:
from langchain_huggingface import HuggingFaceEmbeddings 

embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

embeddings = embedding_model.embed_query("Hey How are you?")
print(f"Length of embeddings : {len(embeddings)}")
print(f"Sample embeddings : {embeddings[:100]}")

Length of embeddings : 384
Sample embeddings : [-0.013380538672208786, 0.003255972173064947, 0.10806030035018921, 0.08322358131408691, 0.02040085941553116, -0.049066152423620224, 0.0722508355975151, 0.002980925841256976, -0.08823534101247787, 0.016058299690485, -0.03367079421877861, -4.332493062975118e-06, -0.02510129101574421, 0.0007887802203185856, 0.060331884771585464, -0.0415474958717823, 0.07702311128377914, -0.14256997406482697, -0.13958506286144257, 0.06023767963051796, 0.003192346775904298, 0.018982844427227974, 0.02300790697336197, 0.06056844815611839, -0.07911035418510437, -0.05399537831544876, -0.0008475205395370722, 0.03202424943447113, -0.029674910008907318, -0.04484577104449272, -0.10411098599433899, 0.06399180740118027, -0.05713418126106262, -0.02695028856396675, -0.028776653110980988, 0.00333896791562438, -0.0355900302529335, -0.13525626063346863, 0.009469274431467056, 0.0003555373114068061, 0.009924577549099922, -0.0014938903041183949, -0.009747199714183807, -0.0021706

---

#### **Loading the Documents**

In [57]:
from langchain_community.document_loaders import PyPDFLoader

file_path = "../data/Understanding_Climate_Change.pdf"

loader = PyPDFLoader(file_path)
docs = loader.load()
print(f"Number of Docs : {len(docs)}")

Number of Docs : 33


--- 

#### **Document Processor**
- Input documents are split into manageable chunks.
- Each chunk is embedded using the embedding model.
- A vector store is created from these embeddings for efficient similarity search.
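
To make the three steps concrete, here is a small dependency-free sketch. It is a simplification under stated assumptions: `chunk_text` uses fixed-size character windows rather than the recursive splitting done by `RecursiveCharacterTextSplitter`, the vectors are made-up two-dimensional values rather than real embeddings, and `top_k` ranks by brute-force cosine similarity instead of using a FAISS index. All three function names are hypothetical.

```python
import math


def chunk_text(text, chunk_size=400, chunk_overlap=50):
    """Split text into character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k(query_vec, chunk_vecs, k=2):
    """Return the indices of the k chunk vectors most similar to the query."""
    ranked = sorted(
        range(len(chunk_vecs)),
        key=lambda i: cosine(query_vec, chunk_vecs[i]),
        reverse=True,
    )
    return ranked[:k]


chunks = chunk_text("abcdefghij" * 100, chunk_size=400, chunk_overlap=50)
print([len(c) for c in chunks])  # [400, 400, 300]

# Toy vectors standing in for chunk embeddings.
print(top_k([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]], k=2))  # [0, 2]
```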

In [58]:
import faiss 
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS

# Define the DocumentProcessor class
class DocumentProcessor:
    def __init__(self, llm, embedding_model):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
        self.llm = llm
        self.embedding_model = embedding_model 

    def process_documents(self, documents):
        splits = self.text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(splits, self.embedding_model)
        return splits, vector_store
    
    def create_embeddings_batch(self, texts, batch_size=32):
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            batch_embeddings = self.embedding_model.embed_documents(batch)
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)

    def compute_similarity_matrix(self, embeddings):
        return cosine_similarity(embeddings)

---

#### **Knowledge Graph Construction**

- Graph nodes are created for each text chunk.
- Concepts are extracted from each chunk using a combination of NLP techniques and language models.
- Extracted concepts are lemmatized to improve matching.
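- Edges are added between pairs of nodes whose embedding cosine similarity exceeds a threshold (0.8 in this notebook); the concepts the two chunks share are stored on the edge.
- Edge weights represent the strength of a relationship, computed as a weighted sum of the similarity score and the normalized number of shared concepts.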
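
The edge rule can be tried on toy data with a short standalone sketch. It mirrors the formula used in `_calculate_edge_weight` in the cell below (`alpha * similarity + beta * normalized_shared_concepts`, with `alpha=0.7` and `beta=0.3`), but the helper names `edge_weight` and `build_edges`, the two-dimensional vectors, and the concept lists are illustrative assumptions, not part of the notebook.

```python
import math
from itertools import combinations


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def edge_weight(similarity, concepts_a, concepts_b, alpha=0.7, beta=0.3):
    """Blend embedding similarity with the share of overlapping concepts."""
    shared = set(concepts_a) & set(concepts_b)
    max_shared = min(len(concepts_a), len(concepts_b))
    normalized = len(shared) / max_shared if max_shared > 0 else 0.0
    return alpha * similarity + beta * normalized


def build_edges(embeddings, concepts, threshold=0.8):
    """Return {(i, j): weight} for every pair whose similarity exceeds threshold."""
    edges = {}
    for i, j in combinations(range(len(embeddings)), 2):
        sim = cosine(embeddings[i], embeddings[j])
        if sim > threshold:
            edges[(i, j)] = edge_weight(sim, concepts[i], concepts[j])
    return edges


# Toy inputs: chunks 0 and 1 point in nearly the same direction, chunk 2 does not.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
concepts = [["climate", "emission"], ["emission", "policy", "economy"], ["health"]]

edges = build_edges(embeddings, concepts)
print(list(edges))  # [(0, 1)]
```

Only the pair (0, 1) clears the threshold. It shares one concept out of a possible two, so the concept term contributes `0.3 * 0.5` on top of the similarity term.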

In [59]:
# Define the Concepts class
import spacy
from spacy.cli import download
from pydantic import BaseModel, Field
from typing import List, Annotated
import networkx as nx 

# Data Validation Model 
class Concepts(BaseModel):
    concepts_list: Annotated[List[str], Field(description="List of concepts")]
    
# Define the KnowledgeGraph class
class KnowledgeGraph:
    def __init__(self):
        self.graph = nx.Graph() # networkx Graph 
        self.lemmatizer = WordNetLemmatizer() # An instance of WordNetLemmatizer
        self.concept_cache = {} # A dictionary to cache extracted concepts (content maps to concept)
        self.nlp = self._load_spacy_model()  
        self.edges_threshold = 0.8 
    
    def build_graph(self, splits, llm, embedding_model):
        self._add_nodes(splits)
        embeddings = self._create_embeddings(splits, embedding_model)
        self._extract_concepts(splits, llm)
        self._add_edges(embeddings)

    def _add_nodes(self, splits):
        "Adding nodes to the Graph."
        for i, split in enumerate(splits):
            self.graph.add_node(i, content=split.page_content)

    def _create_embeddings(self, splits, embedding_model):
        "Get embeddings for each text chunk we have."
        texts = [split.page_content for split in splits]
        return embedding_model.embed_documents(texts)

    def _compute_similarities(self, embeddings):
        "Cosine similarity of embeddings"
        return cosine_similarity(embeddings)

    def _load_spacy_model(self):
        "Load the Spacy Model."
        try:
            return spacy.load("en_core_web_sm")
        except OSError:
            print("Downloading spaCy model...")
            download("en_core_web_sm")
            return spacy.load("en_core_web_sm")

    def _extract_concepts_and_entities(self, content, llm):
        if content in self.concept_cache:
            return self.concept_cache[content]
        # Extract named entities using spaCy
        doc = self.nlp(content)
        named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]
        # Extract general concepts using LLM
        concept_extraction_prompt = PromptTemplate(
            input_variables=["text"],
            template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
        )
        concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
        general_concepts = concept_chain.invoke({"text": content}).concepts_list
        
        # Combine named entities and general concepts
        all_concepts = list(set(named_entities + general_concepts))
        self.concept_cache[content] = all_concepts
        return all_concepts

    def _extract_concepts(self, splits, llm):
        with ThreadPoolExecutor() as executor:
            future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i 
                              for i, split in enumerate(splits)}
            
            for future in tqdm(as_completed(future_to_node), total=len(splits), desc="Extracting concepts and entities"):
                node = future_to_node[future]
                concepts = future.result()
                self.graph.nodes[node]['concepts'] = concepts

    def _add_edges(self, embeddings):
        similarity_matrix = self._compute_similarities(embeddings)
        num_nodes = len(self.graph.nodes)
        
        for node1 in tqdm(range(num_nodes), desc="Adding edges"):
            for node2 in range(node1 + 1, num_nodes):
                similarity_score = similarity_matrix[node1][node2]
                if similarity_score > self.edges_threshold:
                    shared_concepts = set(self.graph.nodes[node1]['concepts']) & set(self.graph.nodes[node2]['concepts'])
                    edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
                    self.graph.add_edge(node1, node2, weight=edge_weight, 
                                        similarity=similarity_score,
                                        shared_concepts=list(shared_concepts))

    def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
        max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
        normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
        return alpha * similarity_score + beta * normalized_shared_concepts

    def _lemmatize_concept(self, concept):
        return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])

---

#### **Query Engine**

- The user query is embedded and used to retrieve relevant documents from the vector store.
- A priority queue is initialized with the nodes corresponding to the most relevant documents.
- The system employs a Dijkstra-like algorithm to traverse the knowledge graph:
    - Nodes are explored in order of their priority (strength of connection to the query).
    - For each explored node:
        - Its content is added to the context.
        - The system checks if the current context provides a complete answer.
        - If the answer is incomplete:
            - The node's concepts are processed and added to a set of visited concepts.
            - Neighboring nodes are explored, with their priorities updated based on edge weights.
            - Nodes are added to the priority queue if a stronger connection is found.
- This process continues until a complete answer is found or the priority queue is exhausted.
- If no complete answer is found after traversing the graph, the system generates a final answer using the accumulated context and a large language model.
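
The traversal loop can be sketched on a toy graph as follows. This is a simplified illustration, assuming a plain adjacency dictionary instead of a networkx graph; the `is_complete` callback is a hypothetical stand-in for the LLM answer check, and the concept-based gating step is omitted. As in the cell below, the cost of following an edge is `1 / weight`, so stronger edges are cheaper.

```python
import heapq


def traverse(graph, seeds, is_complete):
    """Best-first traversal over a weighted graph.

    graph: {node: {neighbor: edge_weight}}
    seeds: {node: initial_distance} for the nodes found by vector retrieval
    is_complete: callable taking the visited path, returns True to stop early
    """
    distances = dict(seeds)
    queue = [(dist, node) for node, dist in seeds.items()]
    heapq.heapify(queue)
    visited, path = set(), []

    while queue:
        dist, node = heapq.heappop(queue)
        # Skip nodes already expanded and stale entries with outdated distances.
        if node in visited or dist > distances.get(node, float("inf")):
            continue
        visited.add(node)
        path.append(node)

        if is_complete(path):
            break

        for neighbor, weight in graph[node].items():
            new_dist = dist + 1.0 / weight  # strong edges add little distance
            if new_dist < distances.get(neighbor, float("inf")):
                distances[neighbor] = new_dist
                heapq.heappush(queue, (new_dist, neighbor))
    return path


graph = {
    0: {1: 0.9, 2: 0.5},
    1: {0: 0.9, 3: 0.8},
    2: {0: 0.5},
    3: {1: 0.8},
}

# Stop as soon as node 3 has been added to the context.
print(traverse(graph, {0: 0.1}, lambda path: 3 in path))  # [0, 1, 2, 3]
```

Node 2 is visited before node 3 even though it hangs off a weak edge, because its accumulated distance from the seed (0.1 + 1/0.5) is still smaller than the two-hop distance to node 3.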

In [61]:
from typing import Tuple

# Define the AnswerCheck class for Structured output from LLM 
class AnswerCheck(BaseModel):
    is_complete: bool = Field(description="Whether the current context provides a complete answer to the query")
    answer: str = Field(description="The current answer based on the context, if any")

# Define the QueryEngine class
class QueryEngine:
    def __init__(self, vector_store, knowledge_graph, llm):
        # Keep references to the vector store, the knowledge graph, and the LLM
        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph 
        self.llm = llm
        self.max_context_length = 4000 # defining the context size
        self.answer_check_chain = self._create_answer_check_chain()

    # this is the chain for checking whether context is enough to provide the answer to questions or not
    def _create_answer_check_chain(self):
        answer_check_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
        )
        return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)

    # invoking check-answer chain
    def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
        response = self.answer_check_chain.invoke({"query": query, "context": context})
        return response.is_complete, response.answer

    def _expand_context(self, query: str, relevant_docs):
        expanded_context = ""
        traversal_path = []
        visited_nodes = set()
        visited_concepts = set()
        filtered_content = {}
        final_answer = ""

        priority_queue = []
        distances = {}

        print("\nTraversing the knowledge graph:")

        # --- Seed the queue from vector retrieval ---
        for doc in relevant_docs:
            node_doc, score = self.vector_store.similarity_search_with_score(
                doc.page_content, k=1
            )[0]

            start_node = next(
                n for n in self.knowledge_graph.graph.nodes
                if self.knowledge_graph.graph.nodes[n]["content"] == node_doc.page_content
            )

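            # FAISS reports an L2 distance by default (lower means closer), and a chunk
            # matched against itself can score 0, so clamp before inverting.
            priority = 1 / max(float(score), 1e-9)
            distances[start_node] = priority
            heapq.heappush(priority_queue, (priority, start_node))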

        step = 0

        # --- Best-first graph traversal ---
        while priority_queue:
            current_priority, current_node = heapq.heappop(priority_queue)

            # Skip stale queue entries
            if current_priority > distances.get(current_node, float("inf")):
                continue

            # Skip already expanded nodes
            if current_node in visited_nodes:
                continue

            # --- Expand node ---
            visited_nodes.add(current_node)
            traversal_path.append(current_node)
            step += 1

            node_data = self.knowledge_graph.graph.nodes[current_node]
            node_content = node_data["content"]
            node_concepts = node_data["concepts"]

            filtered_content[current_node] = node_content
            expanded_context += "\n" + node_content if expanded_context else node_content

            print(f"\nStep {step} - Node {current_node}")
            print(f"Content: {node_content[:100]}...")
            print(f"Concepts: {', '.join(node_concepts)}")
            print("-" * 50)

            # --- Early stopping check ---
            is_complete, answer = self._check_answer(query, expanded_context)
            if is_complete:
                final_answer = answer
                break

            # --- Concept gating ---
            lemmatized = {
                self.knowledge_graph._lemmatize_concept(c)
                for c in node_concepts
            }

            if lemmatized.issubset(visited_concepts):
                continue

            visited_concepts.update(lemmatized)

            # --- Queue neighbors ---
            for neighbor in self.knowledge_graph.graph.neighbors(current_node):
                edge_weight = self.knowledge_graph.graph[current_node][neighbor]["weight"]
                new_distance = current_priority + (1 / edge_weight)

                if new_distance < distances.get(neighbor, float("inf")):
                    distances[neighbor] = new_distance
                    heapq.heappush(priority_queue, (new_distance, neighbor))

        # If traversal ended without a complete answer, final_answer stays empty
        # and query() generates the fallback answer from the accumulated context.

        return expanded_context, traversal_path, filtered_content, final_answer


    def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
        print(f"\nProcessing query: {query}")
        print(f"Getting relevant documents from vectorstore ")
        relevant_docs = self._retrieve_relevant_documents(query)
        # Traverse the knowledge graph, starting from the retrieved chunks, to build the context
        expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(query, relevant_docs)
            
        if not final_answer:
            print("\nGenerating final answer...")
            response_prompt = PromptTemplate(
                input_variables=["query", "context"],
                template="Based on the following context, please answer the query\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
            )
                
            response_chain = response_prompt | self.llm
            input_data = {"query": query, "context": expanded_context}
            response = response_chain.invoke(input_data)
            final_answer = response
        else:
            print("\nComplete answer found during traversal.")
            
        print(f"\nFinal Answer: {final_answer}")
        
        return final_answer, traversal_path, filtered_content

    def _retrieve_relevant_documents(self, query: str):
        print("\nRetrieving relevant documents...")
        retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        # compressor = LLMChainExtractor.from_llm(self.llm)
        # compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        return retriever.invoke(query)

In [62]:
class GraphRAG:
    def __init__(self, llm, embedding_model):
        self.llm = llm
        self.embedding_model = embedding_model
        self.document_processor = DocumentProcessor(llm, embedding_model)
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        # self.visualizer = Visualizer()

    def process_documents(self, documents):
        splits, vector_store = self.document_processor.process_documents(documents)
        self.knowledge_graph.build_graph(splits, self.llm, self.embedding_model)
        self.query_engine = QueryEngine(vector_store, self.knowledge_graph, self.llm)

    def query(self, query: str):
        response, traversal_path, filtered_content = self.query_engine.query(query)
        # if traversal_path:
        #     self.visualizer.visualize_traversal(self.knowledge_graph.graph, traversal_path)
        # else:
        #     print("No traversal path to visualize.")
        return response

In [63]:
graph_rag = GraphRAG(llm, embedding_model)

# Build the vector store and the knowledge graph from the loaded documents
graph_rag.process_documents(docs)

Extracting concepts and entities: 100%|██████████| 215/215 [15:05<00:00,  4.21s/it]
Adding edges: 100%|██████████| 215/215 [00:00<00:00, 32905.50it/s]


In [64]:
graph_rag.query("What is Climate Change?")


Processing query: What is Climate Change?
Getting relevant documents from vectorstore 

Retrieving relevant documents...

Traversing the knowledge graph:

Step 1 - Node 131
Content: Chapter 14: Climate Change and the Economy 
Economic Transformation...
Concepts: Economic Transformation, Economy, the Economy 
Economic Transformation, Climate Change
--------------------------------------------------

Step 2 - Node 101
Content: impacts of climate change. This includes studying the links between climate and health, 
developing ...
Concepts: climate education, curriculum development, Advocacy 
Climate Education 
Curriculum Development, evidence-based policies and interventions, Education, health data systems, links between climate and health, new technologies and treatments
--------------------------------------------------

Step 3 - Node 4
Content: provide a historical record that scientists use to understand past climate conditions and 
predict f...
Concepts: emission, greenhouse gases, 

AIMessage(content='Climate change refers to significant, long-term changes in the global climate. The term "global climate" encompasses the planet\'s overall weather patterns, including temperature, precipitation, and wind patterns, over an extended period.', additional_kwargs={}, response_metadata={'model': 'llama3.2', 'created_at': '2025-12-31T03:29:04.600811Z', 'done': True, 'done_reason': 'stop', 'total_duration': 5797202583, 'load_duration': 161558292, 'prompt_eval_count': 329, 'prompt_eval_duration': 2688650500, 'eval_count': 43, 'eval_duration': 2943080291, 'logprobs': None, 'model_name': 'llama3.2', 'model_provider': 'ollama'}, id='lc_run--019b7273-9fb0-7a11-80b3-b7961444e1bd-0', usage_metadata={'input_tokens': 329, 'output_tokens': 43, 'total_tokens': 372})