## 🕸️ GraphRAG: Graph-Enhanced Retrieval-Augmented Generation | RAG100X

This notebook implements **GraphRAG** — an advanced retrieval strategy that combines **graph-based knowledge representation** with retrieval-augmented generation.  

Instead of treating document chunks as isolated units, GraphRAG builds a **knowledge graph** where chunks are **nodes** and their relationships (shared concepts, semantic similarity) are **edges**. This enables the system to **walk through related ideas**, preserving context and uncovering deeper connections when answering queries.

---

### ✅ What You’ll Learn

- How to build a **knowledge graph** from document chunks  
- How to represent **nodes (text)** and **edges (relationships)**  
- How queries are processed using a **graph traversal algorithm**  
- Why graphs help maintain **long-range context** across documents  
- How visualization reveals the **reasoning path** of the system  

---

### 🔍 Real-world Analogy

Imagine studying history:  

> 📖 You first read about the **French Revolution**  
> 🔗 Then you follow connections to the **Coup of 18 Brumaire**  
> 👑 Finally, you see how it links to **Napoleon becoming Emperor**

✅ That’s GraphRAG — it doesn’t just “fetch text,” it **follows the chain of connected events** to build a complete, coherent answer.

---

### 🧠 How GraphRAG Works Under the Hood

| Step                     | What Happens                                                                 |
|--------------------------|------------------------------------------------------------------------------|
| 1. Document Processing   | Split text into chunks, embed them, and store in a vector DB                 |
| 2. Graph Construction    | Create nodes (chunks), extract concepts, add edges for semantic/concept links|
| 3. Query Retrieval       | Embed query, fetch initial relevant nodes from vector DB                     |
| 4. Graph Traversal       | Use a Dijkstra-like algorithm to walk connected nodes and gather context     |
| 5. Answer Generation     | LLM composes a final response from the gathered context                      |
| 6. Visualization         | Show nodes, edges, and the traversal path for explainability                 |

💡 Unlike flat retrieval, GraphRAG “connects the dots” and **navigates knowledge like a map**.

---

### 🚀 Why GraphRAG Matters

- 🧠 **Preserves context**: Links distant but related concepts across docs  
- 🔍 **Smarter retrieval**: Goes beyond keyword match to follow relationships  
- 🔎 **Explainable**: Graph visualization shows *why* an answer was chosen  
- ⚙️ **Flexible**: Easily integrates new documents and connections  
- ⏩ **Efficient**: Prioritizes the strongest knowledge pathways  

---

### 🏗️ Use Cases Where It Shines

- Research papers with interlinked concepts  
- Historical or scientific content with **cause–effect chains**  
- Knowledge bases where entities are **highly connected**  
- Multi-document reasoning tasks that need **traceable context**  

---

### 🔄 Where This Fits in RAG100X

In earlier projects, you’ve built:

1. Flat retrieval systems with embeddings + vector stores  
2. Context enrichment techniques (CCH, CEW, HyDE, Sub-query Decomp)  
3. Reranking and evaluation pipelines  

Now, you take a **structural leap**:

> 💡 **Don’t just store knowledge — connect it.**  
> GraphRAG transforms isolated chunks into an **interconnected web of reasoning**.

---


In [None]:
# Install required packages (concurrent.futures is part of the Python 3 standard library)
!pip install faiss-cpu langchain langchain-openai matplotlib networkx nltk numpy python-dotenv scikit-learn spacy tqdm

## 🧰 Setup & Imports for GraphRAG | What this cell actually does

This cell wires up **all the building blocks** your GraphRAG notebook will need:
- Graphs (for relationships and traversal),
- Vector search (for semantic lookup),
- NLP tools (to extract concepts & normalize terms),
- LLM glue (prompts, compression, token/cost tracking),
- Viz helpers (to show the graph and the traversal path),
- Env and resources (API keys, NLP models).

Run this once near the top of your notebook. Nothing “big” happens yet (no indexing/retrieval). You’re just **loading tools** and **configuring the environment** so later cells can create the knowledge graph and answer queries.

---

### 🔎 Big Picture (why these pieces?)

| Layer | Purpose | What it enables later |
|---|---|---|
| **Graph** | Build a knowledge graph from chunks (nodes) + relationships (edges) | Dijkstra-like traversal over “meaningful connections” |
| **Vector Store** | Fast semantic lookup from embeddings | Seed nodes for traversal; fallback answers |
| **NLP** | Concept extraction, normalization (lemmatization), tokenization | Better edges (shared concepts), robust matching |
| **LLM Utilities** | Prompting + context compression + token usage tracking | Smaller, sharper contexts + cost visibility |
| **Viz** | Matplotlib shapes & arrows | Explainable path visualization |
| **Infra** | .env keys, parallelism, progress bars | Stable config, faster preprocessing, user feedback |

---

### 📦 What each import is for (and what happens under the hood)

#### Graph & Traversal
- `networkx as nx`  
  Builds and manipulates the **knowledge graph**. Under the hood, it stores nodes and edges as nested dictionaries and provides algorithms (shortest path, centrality). In GraphRAG, you’ll:
  - Create **nodes** = text chunks,
  - Create **edges** = relationships (shared concepts, similarity),
  - Store **weights** = strength of relationships,
  - Traverse with a **priority queue** (Dijkstra-like logic).

#### Retrieval & LLM (LangChain + OpenAI)
- `FAISS` (from `langchain.vectorstores`)  
  A high-performance **similarity search** index. FAISS indexes the chunk embeddings so nearest neighbors can be retrieved **fast**. LangChain's default is an exact flat index; FAISS also offers approximate index types for much larger collections.
- `RecursiveCharacterTextSplitter`  
  Splits large text into **overlapping chunks** that preserve coherence (by default it tries to split on paragraphs → lines → words → characters). Overlaps keep cross-chunk continuity.
- `PromptTemplate`  
  Parameterized prompts for LLM calls. Keeps prompts consistent and reusable across calls.
- `ContextualCompressionRetriever`, `LLMChainExtractor`  
  A **two-step retriever**: first retrieve, then **shrink** content with an LLM to keep only the **most relevant lines**. Under the hood, LangChain runs an LLM prompt that *extracts salient spans* from retrieved docs.
- `get_openai_callback`  
  A context manager that **tallies tokens and estimated cost** for the OpenAI calls made inside its `with` block. Great for **cost control** and debugging.

- `ChatOpenAI` (from `langchain_openai`)  
  LangChain’s wrapper around OpenAI chat models. Handles API calls, retries, temperature, etc.

#### Similarity & Math
- `cosine_similarity` (scikit-learn)  
  Measures **angle similarity** between embeddings (vectors). Used to:
  - Build edges between chunks (semantic closeness),
  - Re-rank neighbors during traversal.
- `numpy as np`  
  Efficient **numerical arrays** and math ops.

#### NLP (Concepts & Normalization)
- `nltk` + `word_tokenize`, `WordNetLemmatizer`  
  - **Tokenization**: split text into words/tokens.
  - **Lemmatization**: reduce words to **base forms** (e.g., “consulates” → “consulate”). This boosts **matching accuracy** when deciding if two chunks share concepts.
- `spacy`, `English`, `spacy.cli.download`  
  - The spaCy pipeline powers **entity detection** and **noun chunk extraction** (depending on the model you load).
  - `English()` gives a light-weight English pipeline; to use trained models (like `en_core_web_sm`), you’ll typically `spacy.cli.download("en_core_web_sm")` and then `spacy.load("en_core_web_sm")`.  
  **Why both NLTK + spaCy?** NLTK’s lemmatizer is simple and reliable; spaCy gives you faster tokenization, POS tags, and entities. Together they help **feature extraction for edges**.

#### Priority Queue & Parallelism
- `heapq`  
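  Python’s **priority queue**. Critical for the **Dijkstra-like traversal**: always expand the most promising node next. `heapq` is a min-heap (it pops the smallest value first), so the traversal stores the *inverse* of connection strength as the priority; the strongest path then comes out first.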
- `ThreadPoolExecutor`, `as_completed`  
  **Run slow, I/O-bound steps concurrently**, mainly the per-chunk LLM calls for concept extraction. Threads overlap the time spent waiting on network responses; they do not speed up CPU-bound work such as pairwise similarity.
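
To make the priority-queue idea concrete, here is a minimal sketch of how inverted scores turn Python's min-heap into a "strongest first" queue. The node names and weights are made up for illustration; they are not taken from the notebook's data.

```python
import heapq

# Hypothetical (edge_weight, node_id) pairs; a larger weight means a stronger connection.
candidates = [(0.92, "chunk_a"), (0.81, "chunk_b"), (0.97, "chunk_c")]

queue = []
for weight, node in candidates:
    # heapq pops the smallest item first, so store 1 / weight as the priority.
    heapq.heappush(queue, (1.0 / weight, node))

visit_order = []
while queue:
    priority, node = heapq.heappop(queue)
    visit_order.append(node)

print(visit_order)  # strongest connection first
```

Negating the weight would work just as well; using the inverse keeps priorities positive, so they can be added up along a path like distances.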

#### Progress & Visualization
- `tqdm`  
  Nice progress bars for long loops (chunking, embedding, graph building).
- `matplotlib.pyplot as plt`, `matplotlib.patches as patches`  
  Drawing the **graph** and the **traversal path**:
  - Nodes as dots/boxes,
  - Edge **color/width** = weight strength,
  - **Curved/dashed arrows** to highlight the path the algorithm actually took.

#### OS & Environment
- `os`, `sys`, `dotenv.load_dotenv()`  
  - Load **environment variables** (e.g., `OPENAI_API_KEY`) from a local `.env`.
  - Manage file paths, optional `sys.path` tweaks.
- `os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"`  
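  Works around a known OpenMP error on some platforms, raised when more than one copy of the OpenMP runtime is loaded (for example via FAISS and NumPy/MKL). It lets the notebook continue, but it is a workaround rather than a real fix.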


---


In [None]:
import networkx as nx
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.callbacks import get_openai_callback

from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import sys
from dotenv import load_dotenv
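from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # OpenAIEmbeddings is used by DocumentProcessor
from pydantic import BaseModel, Field  # used by the structured-output schemas (Concepts, AnswerCheck)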
from typing import List, Tuple, Dict
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import nltk
import spacy
import heapq


from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np

from spacy.cli import download
from spacy.lang.en import English
# Load environment variables from a .env file
load_dotenv()

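# Make sure the OpenAI API key is available, and fail early with a clear message if it is not
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file or environment.")
os.environ["OPENAI_API_KEY"] = api_key

# Workaround for duplicate OpenMP runtime errors on some platforms
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"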

nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)

### Understanding the `DocumentProcessor` Class

The `DocumentProcessor` class is designed as a **utility for preparing documents** so they can be efficiently used in a Retrieval-Augmented Generation (RAG) system. Let’s break down what it does, why it does it, and what happens under the hood at each step.

---

#### 1. Initialization (`__init__`)

When a `DocumentProcessor` object is created:

- **Text Splitter**:  
  It sets up a `RecursiveCharacterTextSplitter` that breaks long documents into smaller overlapping chunks.  
  - **Why**: Large documents cannot be directly fed into language models due to context length limits. Splitting ensures that information is captured in manageable pieces.  
  - **Under the hood**: It recursively looks for natural breakpoints (paragraphs, then lines, then words). If none fit, it falls back to raw character splits. Each chunk is at most 1000 characters, with up to 200 characters of overlap to preserve context between chunks.

- **Embeddings Model**:  
  It initializes `OpenAIEmbeddings`, which will later be used to convert text into high-dimensional numeric vectors.  
  - **Why**: Text alone cannot be searched efficiently. Converting text into embeddings allows semantic similarity search.  
  - **Under the hood**: The embedding model takes each chunk of text and maps it into a vector space (e.g., 1536 dimensions). In this space, semantically similar texts are close to each other.
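
To see what `chunk_size=1000` and `chunk_overlap=200` mean in practice, here is a simplified fixed-window sketch. It is not the `RecursiveCharacterTextSplitter` algorithm (which prefers natural breakpoints); `sliding_window_chunks` is a hypothetical helper that only illustrates how overlapping windows share text.

```python
from typing import List

def sliding_window_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
    """Split text into fixed-size windows that overlap by chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks

# 2500 characters of dummy text
sample = "".join(str(i % 10) for i in range(2500))
chunks = sliding_window_chunks(sample)

print([len(c) for c in chunks])
# The tail of one chunk is repeated at the head of the next one
print(chunks[0][-200:] == chunks[1][:200])
```

The repeated 200 characters are what let a sentence that straddles a boundary appear whole in at least one chunk.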

---

#### 2. `process_documents(documents)`

This method takes raw documents and makes them “search-ready.”

- **Splitting Documents**:  
  The text splitter breaks each document into chunks (as explained above).  
  - Result: A list of smaller, overlapping text segments.

- **Creating a Vector Store (FAISS)**:  
  Each chunk is embedded using `OpenAIEmbeddings`, and the embeddings are stored in a FAISS index.  
  - **Why FAISS**: FAISS is a specialized library for fast similarity search over large embedding collections. LangChain's `FAISS.from_documents` builds an exact flat index by default, which compares the query against every stored vector using optimized routines; FAISS also provides clustering-based (inverted file) indexes for approximate search on much larger collections.  
  - **Under the hood**: Each embedding vector is inserted into FAISS’s internal index. Later, when you query with another embedding, FAISS computes distances and retrieves the closest matches.

- **Output**:  
  - `splits`: The list of text chunks.  
  - `vector_store`: The FAISS index mapping those chunks to their embeddings.

---

#### 3. `create_embeddings_batch(texts, batch_size=32)`

This method converts a large number of texts into embeddings in smaller groups (batches).

- **Batching**:  
  Instead of sending all texts at once, it processes them in groups of 32 by default.  
  - **Why**: APIs and GPU models often have limits. Batching prevents overload and makes the process more efficient.  
  - **Under the hood**: For each batch, `embed_documents` is called, which makes an API call to the embedding model. The responses are combined into one large list.

- **Output**:  
  A NumPy array containing all embeddings.
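
The batching loop is easy to check in isolation. The sketch below mirrors the logic of `create_embeddings_batch` but swaps the real embedding model for `fake_embed`, a made-up stand-in that records how many texts each call receives, so it runs without an API key.

```python
from typing import Callable, List

def embed_in_batches(
    texts: List[str],
    embed_fn: Callable[[List[str]], List[List[float]]],
    batch_size: int = 32,
) -> List[List[float]]:
    """Call embed_fn on consecutive slices of texts and concatenate the results."""
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        embeddings.extend(embed_fn(batch))
    return embeddings

call_sizes = []

def fake_embed(batch: List[str]) -> List[List[float]]:
    # Hypothetical embedding function: returns a tiny 2-d vector per text.
    call_sizes.append(len(batch))
    return [[float(len(text)), 1.0] for text in batch]

texts = [f"chunk {i}" for i in range(70)]
vectors = embed_in_batches(texts, fake_embed)

print(call_sizes)    # sizes of the individual calls
print(len(vectors))  # one vector per input text
```

With 70 texts and the default batch size, the stand-in is called three times, and the last call receives the remaining 6 texts.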

---

#### 4. `compute_similarity_matrix(embeddings)`

This method calculates how similar each embedding is to every other embedding.

- **Cosine Similarity**:  
  It uses cosine similarity, which measures the angle between two vectors.  
  - **Why**: In embedding space, cosine similarity is a common metric because it focuses on direction rather than magnitude (two texts with similar meaning will point in a similar direction).  
  - **Under the hood**: For every pair of embeddings, it computes:  

    \[
    \text{similarity}(A,B) = \frac{A \cdot B}{||A|| \times ||B||}
    \]

  - The result is a square matrix where each cell `(i, j)` represents how similar text `i` is to text `j`.

- **Output**:  
  A 2D NumPy array (matrix) of similarity scores between all texts.
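
As a sanity check on the formula, the sketch below computes the same matrix in plain Python for three toy 2-d vectors. The notebook itself uses scikit-learn's `cosine_similarity`; this version only spells out the arithmetic and assumes no vector is all zeros.

```python
import math
from typing import List

def cosine(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between a and b (both must be non-zero vectors)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def similarity_matrix(vectors: List[List[float]]) -> List[List[float]]:
    """Square matrix where cell (i, j) is the cosine similarity of vectors i and j."""
    return [[cosine(a, b) for b in vectors] for a in vectors]

# Toy embeddings: the first and third point in perpendicular directions
vectors = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
matrix = similarity_matrix(vectors)

for row in matrix:
    print([round(value, 3) for value in row])
```

The diagonal is 1 (every vector matches itself), the matrix is symmetric, and the length of a vector does not matter: `[0.0, 2.0]` scores the same as `[0.0, 1.0]` would.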

---

#### 📌 Why This Matters for RAG

1. **Chunking** ensures the LLM can handle large documents.  
2. **Embeddings + FAISS** make the document collection searchable by meaning, not just keywords.  
3. **Batch embedding** speeds up processing and avoids hitting API limits.  
4. **Similarity matrix** helps in analyzing document relationships, clustering, or debugging retrieval quality.  

Together, these steps prepare raw documents so that an LLM can later **find, retrieve, and use the most relevant context** when answering questions.

---


In [None]:
# Define the DocumentProcessor class
class DocumentProcessor:
    def __init__(self):
        """
        Initializes the DocumentProcessor with a text splitter and OpenAI embeddings.
        
        Attributes:
        - text_splitter: An instance of RecursiveCharacterTextSplitter with specified chunk size and overlap.
        - embeddings: An instance of OpenAIEmbeddings used for embedding documents.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings()

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into smaller chunks and creating a vector store.
        
        Args:
        - documents (list of Document): LangChain Document objects to be processed.
        
        Returns:
        - tuple: A tuple containing:
          - splits (list of Document): The split document chunks.
          - vector_store (FAISS): A FAISS vector store created from the split document chunks and their embeddings.
        """
        splits = self.text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(splits, self.embeddings)
        return splits, vector_store

    def create_embeddings_batch(self, texts, batch_size=32):
        """
        Creates embeddings for a list of texts in batches.
        
        Args:
        - texts (list of str): A list of texts to be embedded.
        - batch_size (int, optional): The number of texts to process in each batch. Default is 32.
        
        Returns:
        - numpy.ndarray: An array of embeddings for the input texts.
        """
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            batch_embeddings = self.embeddings.embed_documents(batch)
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)

    def compute_similarity_matrix(self, embeddings):
        """
        Computes a cosine similarity matrix for a given set of embeddings.
        
        Args:
        - embeddings (numpy.ndarray): An array of embeddings.
        
        Returns:
        - numpy.ndarray: A cosine similarity matrix for the input embeddings.
        """
        return cosine_similarity(embeddings)

### Understanding the `KnowledgeGraph` Class

The `KnowledgeGraph` class is responsible for **turning document chunks into a graph of connected ideas**. Each document chunk becomes a **node**, and an edge is added between two nodes when their embeddings are similar enough; the concepts they share then raise the edge's weight.  

---

#### 1. Initialization
- **Graph**: Uses `networkx.Graph()` to store nodes (document chunks) and edges (relationships).  
- **Lemmatizer**: Reduces words to their base form (e.g., "revolutions" → "revolution"). Helps normalize concepts.  
- **Concept Cache**: Avoids recalculating concepts for the same text.  
- **spaCy NLP Model**: Detects named entities (people, places, orgs, etc.).  
- **Edge Threshold**: Similarity score (0.8) above which two nodes should be connected.  

---

#### 2. Building the Graph (`build_graph`)
The pipeline for constructing the graph:
1. **Add Nodes** → Each text split becomes a graph node with its content.  
2. **Create Embeddings** → Each chunk is embedded into a vector.  
3. **Extract Concepts** → Named entities (via spaCy) + general concepts (via LLM).  
4. **Add Edges** → Connect nodes whose embedding similarity exceeds the threshold; shared concepts increase the edge weight.  

---

#### 3. Adding Nodes (`_add_nodes`)
- Loops through all text splits.  
- Each split is added as a **node** with `page_content`.  

---

#### 4. Creating Embeddings (`_create_embeddings`)
- Converts each chunk’s text into embeddings using the embedding model.  
- These embeddings will later be compared for similarity.  

---

#### 5. Computing Similarities (`_compute_similarities`)
- Uses **cosine similarity** to measure how close each embedding is to every other embedding.  
- Produces a similarity matrix.  

---

#### 6. Extracting Concepts
- **Named Entities** (via spaCy): Keeps entities labeled `PERSON`, `ORG`, `GPE`, or `WORK_OF_ART` (people, organizations, countries and cities, titled works).  
- **General Concepts** (via LLM): Extracts abstract ideas from text.  
- **Combination**: Both sets are merged into a list of concepts for each node.  
- Uses **multi-threading** for efficiency.  
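
The multi-threading pattern can be sketched without spaCy or an LLM. Below, `extract_concepts_stub` is a made-up stand-in for the real extraction step (it just collects capitalized words); the point is how each future is mapped back to its node index, since `as_completed` yields results in whatever order they finish.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def extract_concepts_stub(text: str) -> list:
    """Hypothetical stand-in for spaCy + LLM extraction: returns capitalized words."""
    return sorted({word.strip(".,") for word in text.split() if word[0].isupper()})

chunks = [
    "Napoleon rose to power in France.",
    "The Directory governed France before the coup.",
]

concepts_by_node = {}
with ThreadPoolExecutor(max_workers=4) as executor:
    # Remember which node each future belongs to
    future_to_node = {
        executor.submit(extract_concepts_stub, text): i
        for i, text in enumerate(chunks)
    }
    for future in as_completed(future_to_node):
        node = future_to_node[future]
        concepts_by_node[node] = future.result()

print(concepts_by_node)
```

Threads pay off here because the real extraction step spends most of its time waiting for an API response.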

---

#### 7. Adding Edges (`_add_edges`)
- For every pair of nodes:  
  - If similarity > threshold (0.8), check for **shared concepts**.  
  - Compute an **edge weight** (blend of similarity score and shared concepts).  
  - Add an edge with attributes: weight, similarity, and list of shared concepts.  

---

#### 8. Edge Weight Calculation
\[
\text{weight} = \alpha \times \text{similarity} + \beta \times \text{normalized shared concepts}
\]  
- Default: `α = 0.7`, `β = 0.3`.  
- This ensures edges reflect both semantic closeness and conceptual overlap.  
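
A worked example helps to read the formula. The sketch below follows the same arithmetic as `_calculate_edge_weight`, with the shared-concept count normalized by the smaller of the two concept lists; the concept lists and the similarity value of 0.85 are invented for illustration.

```python
def edge_weight(similarity, concepts_a, concepts_b, alpha=0.7, beta=0.3):
    """Blend embedding similarity with the fraction of shared concepts."""
    shared = set(concepts_a) & set(concepts_b)
    max_possible_shared = min(len(concepts_a), len(concepts_b))
    normalized = len(shared) / max_possible_shared if max_possible_shared > 0 else 0
    return alpha * similarity + beta * normalized

# Hypothetical concepts for two chunks
concepts_a = ["napoleon", "coup", "directory"]
concepts_b = ["napoleon", "coup", "consulate", "france"]

# 2 shared concepts out of a possible 3 -> 0.7 * 0.85 + 0.3 * (2 / 3)
weight = edge_weight(0.85, concepts_a, concepts_b)
print(round(weight, 3))
```

Note that the 0.8 threshold is applied to the raw similarity, not to this blended weight, so an edge can exist with a weight slightly below 0.8 when the two chunks share few concepts.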

---

#### 9. Lemmatizing Concepts (`_lemmatize_concept`)
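- Converts each word of a concept into its base form using WordNet's default (noun) lemmatization.  
- Example: `"running shoes"` → `"running shoe"` (plural nouns are reduced; verb forms such as "running" stay as they are unless a part of speech is passed).  
- Helps different forms of the same concept link together. Note that `_extract_concepts_and_entities` does not call this helper, so concepts are compared exactly as extracted unless you apply it yourself.  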

---

#### 📌 Why This Matters
The class turns raw document chunks into a **knowledge graph**, where:
- **Nodes** = pieces of text.  
- **Edges** = meaningful relationships (semantic similarity + shared ideas).  

This allows for more structured retrieval, making RAG systems not just retrieve “similar text” but also understand **conceptual connections** between documents.


In [None]:
# Define the Concepts class (structured-output schema for LLM concept extraction)
class Concepts(BaseModel):
    concepts_list: List[str] = Field(description="List of concepts")

# Define the KnowledgeGraph class
class KnowledgeGraph:
    def __init__(self):
        """
        Initializes the KnowledgeGraph with a graph, lemmatizer, and NLP model.
        
        Attributes:
        - graph: An instance of a networkx Graph.
        - lemmatizer: An instance of WordNetLemmatizer.
        - concept_cache: A dictionary to cache extracted concepts.
        - nlp: An instance of a spaCy NLP model.
        - edges_threshold: A float value that sets the threshold for adding edges based on similarity.
        """
        self.graph = nx.Graph()
        self.lemmatizer = WordNetLemmatizer()
        self.concept_cache = {}
        self.nlp = self._load_spacy_model()
        self.edges_threshold = 0.8

    def build_graph(self, splits, llm, embedding_model):
        """
        Builds the knowledge graph by adding nodes, creating embeddings, extracting concepts, and adding edges.
        
        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        - embedding_model: An instance of an embedding model.
        
        Returns:
        - None
        """
        self._add_nodes(splits)
        embeddings = self._create_embeddings(splits, embedding_model)
        self._extract_concepts(splits, llm)
        self._add_edges(embeddings)

    def _add_nodes(self, splits):
        """
        Adds nodes to the graph from the document splits.
        
        Args:
        - splits (list): A list of document splits.
        
        Returns:
        - None
        """
        for i, split in enumerate(splits):
            self.graph.add_node(i, content=split.page_content)

    def _create_embeddings(self, splits, embedding_model):
        """
        Creates embeddings for the document splits using the embedding model.
        
        Args:
        - splits (list): A list of document splits.
        - embedding_model: An instance of an embedding model.
        
        Returns:
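        - numpy.ndarray: An array of embeddings for the document splits.
        """
        texts = [split.page_content for split in splits]
        # embed_documents returns a list of lists; convert it to match the documented return type
        return np.array(embedding_model.embed_documents(texts))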

    def _compute_similarities(self, embeddings):
        """
        Computes the cosine similarity matrix for the embeddings.
        
        Args:
        - embeddings (numpy.ndarray): An array of embeddings.
        
        Returns:
        - numpy.ndarray: A cosine similarity matrix for the embeddings.
        """
        return cosine_similarity(embeddings)

    def _load_spacy_model(self):
        """
        Loads the spaCy NLP model, downloading it if necessary.
        
        Args:
        - None
        
        Returns:
        - spacy.Language: An instance of a spaCy NLP model.
        """
        try:
            return spacy.load("en_core_web_sm")
        except OSError:
            print("Downloading spaCy model...")
            download("en_core_web_sm")
            return spacy.load("en_core_web_sm")

    def _extract_concepts_and_entities(self, content, llm):
        """
        Extracts concepts and named entities from the content using spaCy and a large language model.
        
        Args:
        - content (str): The content from which to extract concepts and entities.
        - llm: An instance of a large language model.
        
        Returns:
        - list: A list of extracted concepts and entities.
        """
        if content in self.concept_cache:
            return self.concept_cache[content]
        
        # Extract named entities using spaCy
        doc = self.nlp(content)
        named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]
        
        # Extract general concepts using LLM
        concept_extraction_prompt = PromptTemplate(
            input_variables=["text"],
            template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
        )
        concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
        general_concepts = concept_chain.invoke({"text": content}).concepts_list
        
        # Combine named entities and general concepts
        all_concepts = list(set(named_entities + general_concepts))
        
        self.concept_cache[content] = all_concepts
        return all_concepts

    def _extract_concepts(self, splits, llm):
        """
        Extracts concepts for all document splits using multi-threading.
        
        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        
        Returns:
        - None
        """
        with ThreadPoolExecutor() as executor:
            future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i 
                              for i, split in enumerate(splits)}
            
            for future in tqdm(as_completed(future_to_node), total=len(splits), desc="Extracting concepts and entities"):
                node = future_to_node[future]
                concepts = future.result()
                self.graph.nodes[node]['concepts'] = concepts

    def _add_edges(self, embeddings):
        """
        Adds edges to the graph based on the similarity of embeddings and shared concepts.
        
        Args:
        - embeddings (numpy.ndarray): An array of embeddings for the document splits.
        
        Returns:
        - None
        """
        similarity_matrix = self._compute_similarities(embeddings)
        num_nodes = len(self.graph.nodes)
        
        for node1 in tqdm(range(num_nodes), desc="Adding edges"):
            for node2 in range(node1 + 1, num_nodes):
                similarity_score = similarity_matrix[node1][node2]
                if similarity_score > self.edges_threshold:
                    shared_concepts = set(self.graph.nodes[node1]['concepts']) & set(self.graph.nodes[node2]['concepts'])
                    edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
                    self.graph.add_edge(node1, node2, weight=edge_weight, 
                                        similarity=similarity_score,
                                        shared_concepts=list(shared_concepts))

    def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
        """
        Calculates the weight of an edge based on similarity score and shared concepts.
        
        Args:
        - node1 (int): The first node.
        - node2 (int): The second node.
        - similarity_score (float): The similarity score between the nodes.
        - shared_concepts (set): The set of shared concepts between the nodes.
        - alpha (float, optional): The weight of the similarity score. Default is 0.7.
        - beta (float, optional): The weight of the shared concepts. Default is 0.3.
        
        Returns:
        - float: The calculated weight of the edge.
        """
        max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
        normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
        return alpha * similarity_score + beta * normalized_shared_concepts

    def _lemmatize_concept(self, concept):
        """
        Lemmatizes a given concept.
        
        Args:
        - concept (str): The concept to be lemmatized.
        
        Returns:
        - str: The lemmatized concept.
        """
        return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])

In [None]:
# Create the data directory (nothing is downloaded in this cell)
import os
os.makedirs('data', exist_ok=True)

### Understanding the `QueryEngine` Class

The `QueryEngine` class is responsible for **answering user queries** by combining:
- **Vector search** (to fetch relevant docs quickly)  
- **Knowledge graph traversal** (to expand context using relationships)  
- **LLM reasoning** (to check if an answer is complete, or generate one if missing)  

---

#### 1. Initialization
- **Vector Store** → Stores embeddings of documents for similarity search.  
- **Knowledge Graph** → Graph of concepts & relationships between chunks.  
- **LLM** → Used for both concept extraction and final answering.  
- **Answer Check Chain** → A small LLM prompt that verifies if context provides a complete answer.  
- **Max Context Length** → Keeps context within token budget.  

---

#### 2. Answer Checking (`_check_answer`)
- Takes a **query + context**.  
- Asks the LLM: *“Does this context fully answer the query?”*  
- Returns:  
  - `is_complete` (True/False)  
  - `answer` (if complete).  

Example:  
> Query: "Who wrote Hamlet?"  
> Context: "Hamlet is a tragedy written by William Shakespeare in the early 1600s."  
✅ Complete → Answer: "William Shakespeare".  

---

#### 3. Retrieving Relevant Documents (`_retrieve_relevant_documents`)
- Uses the vector store to get top-k similar chunks.  
- Compresses them with another LLM step (removes fluff).  
- Returns a list of relevant docs to start with.  

---

#### 4. Expanding Context (`_expand_context`)
This is the **core traversal logic**, similar to Dijkstra’s algorithm but adapted for knowledge graphs.  

#### Steps:
1. **Initialize**  
   - Start from the most relevant nodes (from vector search).  
   - Push them into a priority queue (priority = inverse of similarity/connection strength).  

2. **Traversal**  
   - Always pop the node with the **lowest priority value**, which corresponds to the strongest connection.  
   - Add its content to `expanded_context`.  
   - Check if context now gives a complete answer.  

3. **Concept Handling**  
   - Track which concepts have been “covered”.  
   - Only expand to neighbors that introduce **new concepts**.  

4. **Neighbor Expansion**  
   - For each neighbor, calculate new priority (`1 / edge_weight`).  
   - If this path is stronger than before, update the queue.  

5. **Termination**  
   - Stop if a **complete answer** is found.  
   - If queue is empty but no complete answer → fallback to LLM to generate one.  
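
The steps above can be sketched as a small, self-contained function. This is a simplified illustration under stated assumptions, not the notebook's `_expand_context`: the graph is a plain dictionary instead of a `networkx` graph, the LLM answer check is replaced by a caller-supplied `is_complete` function, all scores are assumed to be positive, and the toy graph and concepts are invented.

```python
import heapq
from typing import Callable, Dict, List, Set

def expand_context(
    graph: Dict[int, Dict[int, float]],   # node -> {neighbor: edge_weight}
    concepts: Dict[int, Set[str]],        # node -> concepts found in that chunk
    seeds: Dict[int, float],              # seed node -> similarity to the query
    is_complete: Callable[[List[int]], bool],
) -> List[int]:
    """Return the order in which nodes are visited."""
    # Lower priority value = stronger connection, so invert the scores.
    best = {node: 1.0 / sim for node, sim in seeds.items()}
    queue = [(priority, node) for node, priority in best.items()]
    heapq.heapify(queue)

    visited: Set[int] = set()
    covered: Set[str] = set()
    path: List[int] = []

    while queue:
        priority, node = heapq.heappop(queue)
        if node in visited or priority > best[node]:
            continue  # stale queue entry
        visited.add(node)
        path.append(node)
        covered |= concepts[node]
        if is_complete(path):
            break  # enough context gathered
        for neighbor, weight in graph[node].items():
            if neighbor in visited or concepts[neighbor] <= covered:
                continue  # already seen, or adds no new concept
            candidate = priority + 1.0 / weight
            if candidate < best.get(neighbor, float("inf")):
                best[neighbor] = candidate
                heapq.heappush(queue, (candidate, neighbor))
    return path

# Toy graph: 0 -- 1 -- 3, plus a weak link 0 -- 2 that adds no new concept
graph = {0: {1: 0.9, 2: 0.5}, 1: {0: 0.9, 3: 0.8}, 2: {0: 0.5}, 3: {1: 0.8}}
concepts = {0: {"revolution"}, 1: {"coup"}, 2: {"revolution"}, 3: {"emperor"}}

# Hypothetical stop rule: the answer is "complete" once node 3 has been read
path = expand_context(graph, concepts, {0: 0.9}, lambda visited_path: 3 in visited_path)
print(path)
```

Node 2 is never visited: its only concept is already covered by node 0, so the traversal spends its budget on chunks that add something new.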

---

#### 5. Query Flow (`query`)
Full end-to-end process:
1. **Retrieve relevant docs** via vector store.  
2. **Expand context** by traversing knowledge graph.  
3. **Check completeness**:  
   - If complete → return answer.  
   - If incomplete → LLM synthesizes final answer from expanded context.  
4. Returns:  
   - `final_answer`  
   - `traversal_path` (nodes visited)  
   - `filtered_content` (mapped node IDs → text).  

---

#### 📌 Why This Matters
Unlike plain vector search, this approach:
- **Understands relationships** between chunks (via graph edges).  
- **Expands context meaningfully** (instead of just concatenating top-k).  
- **Knows when it has enough info** (via answer check).  
- Produces **more reliable and interpretable answers** for complex queries.  


In [None]:
# Define the AnswerCheck class
class AnswerCheck(BaseModel):
    is_complete: bool = Field(description="Whether the current context provides a complete answer to the query")
    answer: str = Field(description="The current answer based on the context, if any")

# Define the QueryEngine class
class QueryEngine:
    def __init__(self, vector_store, knowledge_graph, llm):
        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph
        self.llm = llm
        self.max_context_length = 4000
        self.answer_check_chain = self._create_answer_check_chain()

    def _create_answer_check_chain(self):
        """
        Creates a chain to check if the context provides a complete answer to the query.
        
        Args:
        - None
        
        Returns:
        - Chain: A chain to check if the context provides a complete answer.
        """
        answer_check_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
        )
        return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)

    def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
        """
        Checks if the current context provides a complete answer to the query.
        
        Args:
        - query (str): The query to be answered.
        - context (str): The current context.
        
        Returns:
        - tuple: A tuple containing:
          - is_complete (bool): Whether the context provides a complete answer.
          - answer (str): The answer based on the context, if complete.
        """
        response = self.answer_check_chain.invoke({"query": query, "context": context})
        return response.is_complete, response.answer

  

    def _expand_context(self, query: str, relevant_docs) -> Tuple[str, List[int], Dict[int, str], str]:
        """
        Expands the context by traversing the knowledge graph using a Dijkstra-like approach.
        
        This method implements a modified version of Dijkstra's algorithm to explore the knowledge graph,
        prioritizing the most relevant and strongly connected information. The algorithm works as follows:

        1. Initialize:
           - Start with nodes corresponding to the most relevant documents.
           - Use a priority queue to manage the traversal order, where priority is based on connection strength.
           - Maintain a dictionary of best known "distances" (inverse of connection strengths) to each node.

        2. Traverse:
           - Always explore the node with the highest priority (strongest connection) next.
           - For each node, check if we've found a complete answer.
           - Explore the node's neighbors, updating their priorities if a stronger connection is found.

        3. Concept Handling:
           - Track visited concepts to guide the exploration towards new, relevant information.
           - Expand to neighbors only if they introduce new concepts.

        4. Termination:
           - Stop if a complete answer is found.
           - Continue until the priority queue is empty (all reachable nodes explored).

        This approach ensures that:
        - We prioritize the most relevant and strongly connected information.
        - We explore new concepts systematically.
        - We find the most relevant answer by following the strongest connections in the knowledge graph.

        Args:
        - query (str): The query to be answered.
        - relevant_docs (List[Document]): A list of relevant documents to start the traversal.

        Returns:
        - tuple: A tuple containing:
          - expanded_context (str): The accumulated context from traversed nodes.
          - traversal_path (List[int]): The sequence of node indices visited.
          - filtered_content (Dict[int, str]): A mapping of node indices to their content.
          - final_answer (str): The final answer found, if any.
        """
        # Initialize variables
        expanded_context = ""
        traversal_path = []
        visited_concepts = set()
        filtered_content = {}
        final_answer = ""
        
        priority_queue = []
        distances = {}  # Stores the best known "distance" (inverse of connection strength) to each node
        
        print("\nTraversing the knowledge graph:")
        
        # Initialize priority queue with closest nodes from relevant docs
        for doc in relevant_docs:
            # Find the most similar node in the knowledge graph for each relevant document
            closest_nodes = self.vector_store.similarity_search_with_score(doc.page_content, k=1)
            closest_node_content, similarity_score = closest_nodes[0]
            
            # Get the corresponding node in our knowledge graph
            closest_node = next(n for n in self.knowledge_graph.graph.nodes if self.knowledge_graph.graph.nodes[n]['content'] == closest_node_content.page_content)
            
            # Initialize priority (inverse of similarity score for min-heap behavior)
            priority = 1 / similarity_score
            heapq.heappush(priority_queue, (priority, closest_node))
            distances[closest_node] = priority
        
        step = 0
        while priority_queue:
            # Get the node with the highest priority (lowest distance value)
            current_priority, current_node = heapq.heappop(priority_queue)
            
            # Skip if we've already found a better path to this node
            if current_priority > distances.get(current_node, float('inf')):
                continue
            
            if current_node not in traversal_path:
                step += 1
                traversal_path.append(current_node)
                node_content = self.knowledge_graph.graph.nodes[current_node]['content']
                node_concepts = self.knowledge_graph.graph.nodes[current_node]['concepts']
                
                # Add node content to our accumulated context
                filtered_content[current_node] = node_content
                expanded_context += "\n" + node_content if expanded_context else node_content
                
                # Log the current step for debugging and visualization
                print(f"\nStep {step} - Node {current_node}:")
                print(f"Content: {node_content[:100]}...") 
                print(f"Concepts: {', '.join(node_concepts)}")
                print("-" * 50)
                
                # Check if we have a complete answer with the current context
                is_complete, answer = self._check_answer(query, expanded_context)
                if is_complete:
                    final_answer = answer
                    break
                
                # Process the concepts of the current node
                node_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in node_concepts)
                if not node_concepts_set.issubset(visited_concepts):
                    visited_concepts.update(node_concepts_set)
                    
                    # Explore neighbors
                    for neighbor in self.knowledge_graph.graph.neighbors(current_node):
                        edge_data = self.knowledge_graph.graph[current_node][neighbor]
                        edge_weight = edge_data['weight']
                        
                        # Calculate new distance (priority) to the neighbor
                        # Note: We use 1 / edge_weight because higher weights mean stronger connections
                        distance = current_priority + (1 / edge_weight)
                        
                        # If we've found a stronger connection to the neighbor, update its distance
                        if distance < distances.get(neighbor, float('inf')):
                            distances[neighbor] = distance
                            heapq.heappush(priority_queue, (distance, neighbor))
                            
                            # Process the neighbor node if it's not already in our traversal path
                            if neighbor not in traversal_path:
                                step += 1
                                traversal_path.append(neighbor)
                                neighbor_content = self.knowledge_graph.graph.nodes[neighbor]['content']
                                neighbor_concepts = self.knowledge_graph.graph.nodes[neighbor]['concepts']
                                
                                filtered_content[neighbor] = neighbor_content
                                expanded_context += "\n" + neighbor_content if expanded_context else neighbor_content
                                
                                # Log the neighbor node information
                                print(f"\nStep {step} - Node {neighbor} (neighbor of {current_node}):")
                                print(f"Content: {neighbor_content[:100]}...")
                                print(f"Concepts: {', '.join(neighbor_concepts)}")
                                print("-" * 50)
                                
                                # Check if we have a complete answer after adding the neighbor's content
                                is_complete, answer = self._check_answer(query, expanded_context)
                                if is_complete:
                                    final_answer = answer
                                    break
                                
                                # Process the neighbor's concepts
                                neighbor_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in neighbor_concepts)
                                if not neighbor_concepts_set.issubset(visited_concepts):
                                    visited_concepts.update(neighbor_concepts_set)
                
                # If we found a final answer, break out of the main loop
                if final_answer:
                    break

        # If we haven't found a complete answer, generate one using the LLM
        if not final_answer:
            print("\nGenerating final answer...")
            response_prompt = PromptTemplate(
                input_variables=["query", "context"],
                template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
            )
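            response_chain = response_prompt | self.llm
            input_data = {"query": query, "context": expanded_context}
            response = response_chain.invoke(input_data)
            # Chat models return a message object; keep only its text so final_answer is a str
            final_answer = getattr(response, "content", response)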

        return expanded_context, traversal_path, filtered_content, final_answer

    def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
        """
        Processes a query by retrieving relevant documents, expanding the context, and generating the final answer.
        
        Args:
        - query (str): The query to be answered.
        
        Returns:
        - tuple: A tuple containing:
          - final_answer (str): The final answer to the query.
          - traversal_path (list): The traversal path of nodes in the knowledge graph.
          - filtered_content (dict): The filtered content of nodes.
        """
        with get_openai_callback() as cb:
            print(f"\nProcessing query: {query}")
            relevant_docs = self._retrieve_relevant_documents(query)
            expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(query, relevant_docs)
            
            if not final_answer:
                print("\nGenerating final answer...")
                response_prompt = PromptTemplate(
                    input_variables=["query", "context"],
                    template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
                )
                
                response_chain = response_prompt | self.llm
                input_data = {"query": query, "context": expanded_context}
                response = response_chain.invoke(input_data)
                # Chat models return a message object; keep only its text
                final_answer = getattr(response, "content", response)
            else:
                print("\nComplete answer found during traversal.")
            
            print(f"\nFinal Answer: {final_answer}")
            print(f"\nTotal Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        
        return final_answer, traversal_path, filtered_content

    def _retrieve_relevant_documents(self, query: str):
        """
        Retrieves relevant documents based on the query using the vector store.
        
        Args:
        - query (str): The query to be answered.
        
        Returns:
        - list: A list of relevant documents.
        """
        print("\nRetrieving relevant documents...")
        retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        compressor = LLMChainExtractor.from_llm(self.llm)
        compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        return compression_retriever.invoke(query)

# Visualizer Class – Deep Explanation

The **Visualizer class** provides two main functionalities:
1. **`visualize_traversal`** → Creates a graphical visualization of the traversal path on a knowledge graph.  
2. **`print_filtered_content`** → Prints the filtered content of the nodes that were visited during traversal, in the exact order they were explored.

---

## 1. `visualize_traversal(graph, traversal_path)`

This method takes:
- A **knowledge graph** (built using NetworkX).
- A **traversal path** (a list of node IDs in the order they were visited).  

It then generates a **visual flow diagram** showing:
- Nodes.
- Edges (with weights).
- The exact traversal path, highlighted using red arrows.
- Start node (highlighted in **green**) and end node (highlighted in **red**).
- A legend and color bar for better interpretation.

### Step-by-step explanation

1. **Initialize a directed graph**  
   - Create a new directed graph (`nx.DiGraph`) and copy over all nodes + edges from the original graph.  
   - This ensures we have a clean graph specifically for visualization.

2. **Set up the figure**  
   - Use `matplotlib` to create a canvas (`fig, ax = plt.subplots(figsize=(16, 12))`).  
   - The size ensures readability of nodes, labels, and paths.

3. **Node positioning**  
   - Generate node positions using `spring_layout`.  
   - This layout spreads nodes out in a visually pleasing way (like magnets repelling each other).

4. **Draw regular edges**  
   - Edges are drawn in **blue shades**, where the intensity of the color depends on the **edge weight**.  
   - A color map (`plt.cm.Blues`) maps weights to colors.  
   - This helps us see which edges are "stronger" or more important.

5. **Draw nodes**  
   - All nodes are drawn as **light blue circles**.  
   - Size is large (`3000`) so text labels fit well.

6. **Highlight traversal path**  
   - For each consecutive pair `(start → end)` in the traversal path:  
     - Draw a **red dashed curved arrow** (`FancyArrowPatch`).  
     - This makes the traversal stand out compared to regular edges.  
   - Arrows are curved slightly (`rad=0.3`) so they don’t overlap with regular edges.

7. **Prepare node labels**  
   - Each visited node is labeled with the step number and its main concept (if available). Example: `1. Physics`.  
   - Non-visited nodes just display their concept name (if any).

8. **Highlight start and end nodes**  
   - Start node = **light green**.  
   - End node = **light coral (red-shaded)**.  
   - This makes it immediately clear where traversal began and ended.

9. **Color bar for edge weights**  
   - Adds a vertical color scale on the side to explain what the blue edge intensities mean (low weight → light blue, high weight → dark blue).

10. **Legend**  
    - Explains:
      - Blue line = regular edge.  
      - Red dashed line = traversal path.  
      - Green node = start node.  
      - Red node = end node.  

11. **Show the graph**  
    - Finally, `plt.show()` renders the graph for visualization.
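Two of the steps above, pairing consecutive nodes for the arrows (step 6) and preparing labels (step 7), are plain data manipulation and can be sketched without any plotting. The helper names `traversal_edges` and `build_labels` and the sample data are assumptions for illustration; the notebook performs the same logic inline in `visualize_traversal`.

```python
def traversal_edges(path):
    """Pair each visited node with the next one: these pairs become the red arrows."""
    return list(zip(path, path[1:]))


def build_labels(all_nodes, path, concepts):
    """Label visited nodes as '<step>. <first concept>' and the rest by concept only."""
    labels = {}
    for step, node in enumerate(path, start=1):
        node_concepts = concepts.get(node, [])
        labels[node] = f"{step}. {node_concepts[0] if node_concepts else ''}"
    for node in all_nodes:
        if node not in labels:
            node_concepts = concepts.get(node, [])
            labels[node] = node_concepts[0] if node_concepts else ""
    return labels


all_nodes = [0, 1, 2, 3]
path = [2, 0, 3]
concepts = {0: ["greenhouse gases"], 1: ["oceans"], 2: ["climate"], 3: []}

print(traversal_edges(path))
print(build_labels(all_nodes, path, concepts))
```

Note that a visited node with no concepts still gets its step number, so the order of traversal stays readable even when concept extraction returned nothing.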

---

## 2. `print_filtered_content(traversal_path, filtered_content)`

This method prints the **content of nodes** in the order they were visited.

### Step-by-step explanation:

1. **Iterate through traversal path**  
   - For each step in the traversal, fetch the node ID and its content.

2. **Print step details**  
   - Example format:  
     ```
     Step 1 - Node 3:
     Filtered Content: Einstein’s Theory of Relativity emphasizes...
     --------------------------------------------------
     ```
   - Only the **first 200 characters** of the content are shown (to keep output concise).

3. **Fallback if no content**  
   - If the node ID is missing from `filtered_content`, prints: `"No filtered content available"`.

This provides a **textual log** of what knowledge was encountered along the path, complementing the graphical visualization.
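The formatting of a single step can be isolated in a small helper. `format_step` and its `limit` parameter are hypothetical names used only for this sketch; the notebook prints the same text directly inside `print_filtered_content`.

```python
def format_step(step, node, filtered_content, limit=200):
    """Return the log text for one traversal step, truncating long content."""
    # dict.get falls back to the placeholder only when the node ID is absent
    text = filtered_content.get(node, "No filtered content available")
    return f"Step {step} - Node {node}:\nFiltered Content: {text[:limit]}..."


filtered_content = {3: "x" * 300}
print(format_step(1, 3, filtered_content))  # content cut to 200 characters
print(format_step(2, 9, filtered_content))  # node 9 was never stored
```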

---

## ✅ Why is this useful?

- **For debugging**: You can see if the traversal followed the right sequence.  
- **For learning**: Helps trace knowledge flow step by step.  
- **For presentations**: The visualization makes it easier for non-technical stakeholders to understand the graph logic.  
- **For analysis**: Edge weights and node labels provide context for why a certain path was chosen.

---


In [None]:

# Import necessary libraries
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Define the Visualizer class
class Visualizer:
    @staticmethod
    def visualize_traversal(graph, traversal_path):
        """
        Visualizes the traversal path on the knowledge graph with nodes, edges, and traversal path highlighted.

        Args:
        - graph (networkx.Graph): The knowledge graph containing nodes and edges.
        - traversal_path (list of int): The list of node indices representing the traversal path.

        Returns:
        - None
        """
        traversal_graph = nx.DiGraph()
        
        # Add nodes and edges from the original graph
        for node in graph.nodes():
            traversal_graph.add_node(node)
        for u, v, data in graph.edges(data=True):
            traversal_graph.add_edge(u, v, **data)
        
        fig, ax = plt.subplots(figsize=(16, 12))
        
        # Generate positions for all nodes
        pos = nx.spring_layout(traversal_graph, k=1, iterations=50)
        
        # Draw regular edges with color based on weight
        edges = traversal_graph.edges()
        edge_weights = [traversal_graph[u][v].get('weight', 0.5) for u, v in edges]
        nx.draw_networkx_edges(traversal_graph, pos, 
                               edgelist=edges,
                               edge_color=edge_weights,
                               edge_cmap=plt.cm.Blues,
                               width=2,
                               ax=ax)
        
        # Draw nodes
        nx.draw_networkx_nodes(traversal_graph, pos, 
                               node_color='lightblue',
                               node_size=3000,
                               ax=ax)
        
        # Draw traversal path with curved arrows
        for i in range(len(traversal_path) - 1):
            start_pos = pos[traversal_path[i]]
            end_pos = pos[traversal_path[i + 1]]
            
            # Draw a curved, dashed red arrow; rad=0.3 bends it away from the straight regular edge
            arrow = patches.FancyArrowPatch(start_pos, end_pos,
                                            connectionstyle="arc3,rad=0.3",
                                            color='red',
                                            arrowstyle="->",
                                            mutation_scale=20,
                                            linestyle='--',
                                            linewidth=2,
                                            zorder=4)
            ax.add_patch(arrow)
        
        # Prepare labels for the nodes
        labels = {}
        for i, node in enumerate(traversal_path):
            concepts = graph.nodes[node].get('concepts', [])
            label = f"{i + 1}. {concepts[0] if concepts else ''}"
            labels[node] = label
        
        for node in traversal_graph.nodes():
            if node not in labels:
                concepts = graph.nodes[node].get('concepts', [])
                labels[node] = concepts[0] if concepts else ''
        
        # Draw labels
        nx.draw_networkx_labels(traversal_graph, pos, labels, font_size=8, font_weight="bold", ax=ax)
        
        # Highlight start and end nodes
        start_node = traversal_path[0]
        end_node = traversal_path[-1]
        
        nx.draw_networkx_nodes(traversal_graph, pos, 
                               nodelist=[start_node], 
                               node_color='lightgreen', 
                               node_size=3000,
                               ax=ax)
        
        nx.draw_networkx_nodes(traversal_graph, pos, 
                               nodelist=[end_node], 
                               node_color='lightcoral', 
                               node_size=3000,
                               ax=ax)
        
        ax.set_title("Graph Traversal Flow")
        ax.axis('off')
        
        # Add colorbar for edge weights
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues, norm=plt.Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        cbar.set_label('Edge Weight', rotation=270, labelpad=15)
        
        # Add legend
        regular_line = plt.Line2D([0], [0], color='blue', linewidth=2, label='Regular Edge')
        traversal_line = plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Traversal Path')
        start_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15, label='Start Node')
        end_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral', markersize=15, label='End Node')
        legend = plt.legend(handles=[regular_line, traversal_line, start_point, end_point], loc='upper left', bbox_to_anchor=(0, 1), ncol=2)
        legend.get_frame().set_alpha(0.8)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def print_filtered_content(traversal_path, filtered_content):
        """
        Prints the filtered content of visited nodes in the order of traversal.

        Args:
        - traversal_path (list of int): The list of node indices representing the traversal path.
        - filtered_content (dict of int: str): A dictionary mapping node indices to their filtered content.

        Returns:
        - None
        """
        print("\nFiltered content of visited nodes in order of traversal:")
        for i, node in enumerate(traversal_path):
            print(f"\nStep {i + 1} - Node {node}:")
            print(f"Filtered Content: {filtered_content.get(node, 'No filtered content available')[:200]}...")  # Print first 200 characters
            print("-" * 50)

### Explanation of the `GraphRAG` Class

The `GraphRAG` class is a high-level orchestrator that ties together all the components needed to build a **graph-based Retrieval-Augmented Generation (RAG) system**. Instead of only relying on vector similarity search (like vanilla RAG), it also builds a **knowledge graph** from documents, enabling more structured reasoning, path traversal, and visualization of how the system derives answers.

---

#### 🔑 Key Responsibilities of `GraphRAG`

1. **Initialization (`__init__`)**  
   - **LLM (`ChatOpenAI`)**: The large language model used for generating responses, extracting relationships, and reasoning over graph nodes.  
   - **Embedding model (`OpenAIEmbeddings`)**: Transforms document chunks into dense vectors for similarity search.  
   - **Document Processor (`DocumentProcessor`)**: Handles chunking of documents and creation of embeddings/vector store.  
   - **Knowledge Graph (`KnowledgeGraph`)**: Builds a graph whose nodes are document chunks (each annotated with its extracted concepts) and whose weighted edges link related chunks.  
   - **Query Engine (`QueryEngine`)**: Handles the process of interpreting queries, retrieving information from both the vector store and the knowledge graph, and producing a final answer. Initially set to `None` until documents are processed.  
   - **Visualizer (`Visualizer`)**: Generates visual representations of how the query traversed the graph, showing the reasoning path.

   👉 Under the hood: At this stage, the system is just setting up its "tools" but hasn’t yet ingested any documents.

---

2. **Processing Documents (`process_documents`)**  
   - Splits the raw documents into manageable chunks using the `DocumentProcessor`.  
   - Creates embeddings and stores them in a **FAISS vector store** for fast similarity search.  
   - Builds a **knowledge graph** where nodes are the document chunks (tagged with concepts extracted using the LLM) and edges connect chunks that share concepts or are semantically similar.  
   - Instantiates the `QueryEngine`, which knows how to combine vector retrieval + graph traversal for answering queries.

   👉 Under the hood: This step converts **unstructured text** into two structured representations:  
   - **Vector space** (good for fuzzy semantic similarity).  
   - **Graph structure** (good for explicit reasoning and connections).  

   By maintaining both, the system gains **breadth (semantic coverage)** and **depth (logical reasoning through graph paths)**.
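To make the graph side of this concrete, here is a minimal sketch that links chunks by concept overlap. It assumes Jaccard overlap as the edge weight and a hypothetical `build_chunk_graph` helper; the notebook's `KnowledgeGraph.build_graph` also uses the LLM and the embedding model, so its actual weighting may differ.

```python
from itertools import combinations


def build_chunk_graph(chunk_concepts, min_shared=1):
    """Link chunks that share concepts; weight = Jaccard overlap (an assumption).

    chunk_concepts: {chunk_id: set of concept strings}
    Returns {chunk_id: {neighbor_id: weight}} as an undirected adjacency map.
    """
    graph = {node: {} for node in chunk_concepts}
    for a, b in combinations(chunk_concepts, 2):
        shared = chunk_concepts[a] & chunk_concepts[b]
        if len(shared) >= min_shared:
            union = chunk_concepts[a] | chunk_concepts[b]
            weight = len(shared) / len(union)
            graph[a][b] = weight
            graph[b][a] = weight
    return graph


chunk_concepts = {
    0: {"co2", "warming"},
    1: {"co2", "fossil fuels"},
    2: {"biodiversity"},
}
graph = build_chunk_graph(chunk_concepts)
print(graph)
```

Chunks 0 and 1 end up connected through the shared concept `co2`, while chunk 2 stays isolated and would only be reached through vector retrieval.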

---

3. **Query Handling (`query`)**  
   - Accepts a user query as input.  
   - Passes it to the `QueryEngine`, which:  
     - Retrieves semantically similar chunks from the vector store.  
     - Explores the knowledge graph to find connected chunks that might be relevant.  
     - Combines this evidence to generate a final response with the LLM.  
   - Collects additional outputs:  
     - **Traversal Path**: The list of node IDs in the order they were visited in the graph.  
     - **Filtered Content**: A mapping from each visited node ID to its chunk text.  
   - If a traversal path exists, calls the `Visualizer` to plot the graph traversal, making the reasoning explainable.  

   👉 Under the hood: This is where the **RAG pipeline activates** — the system fuses vector-based recall with graph-based reasoning. Instead of being a black-box LLM response, you can **see the path** of how the system reasoned.

---

## ⚙️ Why This Design?

- **Vector Store Only (Vanilla RAG)**: Great for retrieving relevant chunks but limited in reasoning over structured relationships.  
- **Graph Only**: Great for explicit knowledge reasoning but brittle if the graph misses relevant details.  
- **Graph + Vector Hybrid (GraphRAG)**: Combines the **flexibility of embeddings** with the **structure of graphs**, allowing:  
  - Better retrieval quality.  
  - Explainability (graph traversal path).  
  - Structured knowledge grounding.  

In short: **`GraphRAG` enables a more powerful, transparent, and structured RAG system.**


In [None]:

class GraphRAG:
    def __init__(self):
        """
        Initializes the GraphRAG system with components for document processing, knowledge graph construction,
        querying, and visualization.
        
        Attributes:
        - llm: An instance of a large language model (LLM) for generating responses.
        - embedding_model: An instance of an embedding model for document embeddings.
        - document_processor: An instance of the DocumentProcessor class for processing documents.
        - knowledge_graph: An instance of the KnowledgeGraph class for building and managing the knowledge graph.
        - query_engine: An instance of the QueryEngine class for handling queries (initialized as None).
        - visualizer: An instance of the Visualizer class for visualizing the knowledge graph traversal.
        """
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
        self.embedding_model = OpenAIEmbeddings()
        self.document_processor = DocumentProcessor()
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        self.visualizer = Visualizer()

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into chunks, embedding them, and building a knowledge graph.
        
        Args:
        - documents (list of str): A list of documents to be processed.
        
        Returns:
        - None
        """
        splits, vector_store = self.document_processor.process_documents(documents)
        self.knowledge_graph.build_graph(splits, self.llm, self.embedding_model)
        self.query_engine = QueryEngine(vector_store, self.knowledge_graph, self.llm)

    def query(self, query: str):
        """
        Handles a query by retrieving relevant information from the knowledge graph and visualizing the traversal path.
        
        Args:
        - query (str): The query to be answered.
        
        Returns:
        - str: The response to the query.
        """
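        # The query engine only exists after process_documents() has run
        if self.query_engine is None:
            raise RuntimeError("No documents processed yet: call process_documents() before query().")
        
        response, traversal_path, filtered_content = self.query_engine.query(query)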
        
        if traversal_path:
            self.visualizer.visualize_traversal(self.knowledge_graph.graph, traversal_path)
        else:
            print("No traversal path to visualize.")
        
        return response

### Define documents path

In [None]:
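path = "data/Understanding_Climate_Change.pdf"

# Load the PDF into a list of LangChain Document objects (one per page),
# so that `documents` is defined for the processing cell further down.
# Assumes the langchain_community and pypdf packages are installed.
from langchain_community.document_loaders import PyPDFLoader

documents = PyPDFLoader(path).load()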

### Create a graph RAG instance


In [None]:

graph_rag = GraphRAG()

### Process the documents and create the graph


In [None]:

graph_rag.process_documents(documents)

### Input a query and get the retrieved information from the graph RAG


In [None]:

query = "what is the main cause of climate change?"
response = graph_rag.query(query)

---

## 📘 Summary & Credits

This notebook is based on the excellent open-source repository [RAG_Techniques by NirDiamant](https://github.com/NirDiamant/RAG_Techniques).  
I referred to that work to understand how the pipeline is structured and then reimplemented the same concept in a **fully self-contained** way, but using recent models — as part of my personal learning journey.

The purpose of this notebook is purely **educational**:  
- To deepen my understanding of Retrieval-Augmented Generation systems  
- To keep a clean, trackable log of what I’ve built and learned  
- And to serve as a future reference for myself or others starting from scratch

To support that, I’ve added clear, concise markdowns throughout the notebook — explaining *why* each package was installed, *why* each line of code exists, and *how* each component fits into the overall RAG pipeline. It’s designed to help anyone (including my future self) grasp the **how** and the **why**, not just the **what**.

## 🔍 Why Use Graph-Based Retrieval in RAG?

Traditional vector retrieval focuses only on semantic similarity between chunks.  
While effective, it often misses **relationships between concepts** that span multiple chunks or documents.  

**Graph-Based Retrieval** addresses this by:  
- 🧩 Capturing **entities and relationships** (knowledge graph) from documents  
- 🔗 Allowing **multi-hop reasoning** across connected concepts  
- 🎯 Providing more **contextually coherent answers** by traversing related nodes  

---

## 🧠 What’s New in This Version?

This implementation includes:  

- 🧱 A **GraphRAG class** that integrates LLM, embeddings, document processing, graph building, querying, and visualization  
- 📚 **KnowledgeGraph** to represent document chunks as nodes and their relationships (shared concepts, semantic similarity) as edges  
- 🔍 **QueryEngine** that combines vector search with graph traversal for enriched retrieval  
- 🎨 **Visualizer** to display traversal paths, making reasoning transparent  

---

## 📈 Inferences & Key Takeaways

- ✅ Ideal for domains with **interconnected facts** (research papers, legal docs, technical manuals)  
- 🧠 Goes beyond *chunk-level similarity* by **reasoning over structured relationships**  
- ⚡ Adds interpretability with traversal visualization (you can see “how” the model answered)  
- 🔍 Bridges the gap between **unstructured text** and **structured reasoning**  

---

## 🚀 What Could Be Added Next?

- 📊 Evaluate GraphRAG vs. pure vector retrieval on **accuracy and hallucination reduction**  
- 🧪 Extend to **hybrid retrieval** (graph + BM25 + embeddings) for robustness  
- 🔗 Support **temporal or causal edges** (not just entity links)  
- 🧠 Add **fallback to global retrieval** if graph traversal fails to find an answer  
- ⚙️ Explore **graph algorithms** (PageRank, centrality) to prioritize influential nodes  



## 💡 Final Word

This notebook is part of my larger personal project, **RAG100x**: a challenge to build and log my journey in RAG from 0 to 100 in the coming months.

It’s not built to impress — it’s built to **progress**.  
Everything here is structured to enable **daily iteration**, focused experimentation, and clean documentation.

If you're exploring RAG from first principles, feel free to use this as a scaffold for your own builds. And of course — check out the original repository for broader implementations and ideas.