# GraphRAG Implementation

This notebook implements a graph-based RAG system that uses Amazon Neptune for graph storage and OpenSearch for vector storage, and retrieves context with a hybrid search that combines graph and vector results.

## Configuration Options

### Document Processing
- chunk_size: Number of words per chunk (default: 500)
- chunk_overlap: Number of overlapping words between chunks (default: 50)
- enable_chunking: Whether to split documents into chunks (default: True)

### Graph Construction
- min_entity_freq: Minimum frequency for entity inclusion (default: 2)
- max_relation_distance: Maximum token distance for relationship extraction (default: 10)
- confidence_threshold: Minimum confidence for extracted relations (default: 0.5)

### Hybrid Search
- k_graph: Number of graph-based results to retrieve (default: 5)
- k_vector: Number of vector-based results to retrieve (default: 3)
- alpha: Weight of the graph score when combining scores, between 0 and 1; the vector score is weighted by (1 - alpha) (default: 0.7)

### Neptune Settings
- instance_type: Neptune instance type (default: 'db.r6g.xlarge')
- enable_audit: Enable audit logging (default: True)

### API Settings
- max_retries: Maximum number of retry attempts (default: 5)
- min_delay: Minimum delay between retries in seconds (default: 1)
- max_delay: Maximum delay between retries in seconds (default: 60)

## Prerequisites
- Run setup.ipynb first to configure environment
- Neptune cluster must be configured
- spaCy model `en_core_web_sm` must be downloaded (the notebook attempts to download it if it is missing)

In [None]:
import os
import sys
import json
import boto3
import spacy
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional
from tqdm.auto import tqdm

# Add project root to path for imports
project_root = Path("../..").resolve()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import utilities
from utils.aws.opensearch_utils import OpenSearchManager
from utils.aws.neptune_utils import NeptuneManager, NeptuneGraph
from utils.metrics.bedrock_llm import BedrockLLM

In [None]:
class GraphRAG:
    """
    Graph-based RAG implementation using Neptune for graph storage and hybrid search.

    The constructor stores every configuration value on the instance and then
    initialises, in order, the NLP pipeline, the graph store, the vector store
    and the LLM client.
    """

    def __init__(
        self,
        index_name: str,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        enable_chunking: bool = True,
        min_entity_freq: int = 2,
        max_relation_distance: int = 10,
        confidence_threshold: float = 0.5,
        k_graph: int = 5,
        k_vector: int = 3,
        alpha: float = 0.7,
        instance_type: str = "db.r6g.xlarge",
        enable_audit: bool = True,
        max_retries: int = 5,
        min_delay: float = 1.0,
        max_delay: float = 60.0
    ):
        """
        Initialize GraphRAG with configuration parameters.

        Args:
            index_name: Name for the OpenSearch index (vector store); also used
                to derive the Neptune cluster and OpenSearch domain names
            chunk_size: Number of words per chunk
            chunk_overlap: Number of overlapping words between chunks
            enable_chunking: Whether to split documents into chunks
            min_entity_freq: Minimum frequency for entity inclusion
            max_relation_distance: Maximum token distance for relationships
            confidence_threshold: Minimum confidence for extracted relations
            k_graph: Number of graph-based results to retrieve
            k_vector: Number of vector-based results to retrieve
            alpha: Weight of the graph score when combining graph and vector
                scores (0-1); the vector score is weighted by ``1 - alpha``
            instance_type: Neptune instance type
            enable_audit: Enable audit logging
            max_retries: Maximum retry attempts
            min_delay: Minimum retry delay in seconds
            max_delay: Maximum retry delay in seconds

        Raises:
            ValueError: If ``alpha`` is outside the closed interval [0, 1].
        """
        # Validate before any _init_* helper runs, so a bad weight is reported
        # immediately instead of after the stores have been set up.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be between 0 and 1, got {alpha!r}")

        self.index_name = index_name

        # Document processing config
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.enable_chunking = enable_chunking

        # Graph construction config
        self.min_entity_freq = min_entity_freq
        self.max_relation_distance = max_relation_distance
        self.confidence_threshold = confidence_threshold

        # Search config
        self.k_graph = k_graph
        self.k_vector = k_vector
        self.alpha = alpha

        # Neptune config
        self.instance_type = instance_type
        self.enable_audit = enable_audit

        # API config
        self.max_retries = max_retries
        self.min_delay = min_delay
        self.max_delay = max_delay

        # Initialize components
        self._init_nlp()
        self._init_graph_store()
        self._init_vector_store()
        self._init_llm()

In [None]:
    def _init_nlp(self):
        """
        Initialize SpaCy for entity and relation extraction.

        Loads the ``en_core_web_sm`` pipeline into ``self.nlp``. If loading
        fails with ``OSError``, the model is downloaded once and the load is
        retried.

        Raises:
            OSError: If the download command fails, or the model still cannot
                be loaded after the download.
        """
        # Local imports keep this method self-contained within the notebook cell.
        import subprocess
        import sys

        model_name = "en_core_web_sm"
        try:
            self.nlp = spacy.load(model_name)
        except OSError:
            print("Downloading SpaCy model...")
            try:
                # Use the running interpreter so the model is installed into the
                # same environment as this kernel; an argument list avoids
                # going through a shell.
                subprocess.run(
                    [sys.executable, "-m", "spacy", "download", model_name],
                    check=True,
                )
            except subprocess.CalledProcessError as err:
                raise OSError(
                    f"Failed to download SpaCy model {model_name!r}"
                ) from err
            self.nlp = spacy.load(model_name)
    
    def _init_graph_store(self):
        """
        Set up the Neptune cluster and open the graph interface on it.

        A ``NeptuneManager`` named after the index is stored as
        ``self.neptune_manager``; the endpoint returned by its
        ``setup_cluster()`` call is wrapped in a ``NeptuneGraph`` stored as
        ``self.graph``.
        """
        cluster_name = f"graph-rag-{self.index_name}"
        manager = NeptuneManager(
            cluster_name=cluster_name,
            instance_type=self.instance_type,
            cleanup_enabled=True
        )
        self.neptune_manager = manager

        # The graph interface is built from the endpoint the manager reports.
        self.graph = NeptuneGraph(manager.setup_cluster())
    
    def _init_vector_store(self, dimension: int = 1024):
        """
        Initialize OpenSearch vector store connection.

        Creates an ``OpenSearchManager`` for a domain named after the index,
        sets the domain up, and creates the index with a text ``content``
        field and a k-NN ``vector`` field (HNSW, cosine similarity).

        Args:
            dimension: Length of the embedding vectors stored in the
                ``vector`` field. It must equal the length of the vectors
                returned by the embedding model (default: 1024).
        """
        self.opensearch = OpenSearchManager(
            domain_name=f"graph-rag-{self.index_name}",
            cleanup_enabled=True
        )
        # The returned endpoint is not needed here; the manager itself is used
        # for all later index and search calls.
        self.opensearch.setup_domain()

        # Create index with vector search settings
        index_settings = {
            "settings": {
                "index": {
                    "knn": True,
                    "knn.algo_param.ef_search": 512
                }
            },
            "mappings": {
                "properties": {
                    "content": {"type": "text"},
                    "vector": {
                        "type": "knn_vector",
                        "dimension": dimension,
                        "method": {
                            "name": "hnsw",
                            "space_type": "cosinesimil",
                            "engine": "nmslib",
                            "parameters": {
                                "ef_construction": 512,
                                "m": 16
                            }
                        }
                    }
                }
            }
        }

        self.opensearch.create_index(
            index_name=self.index_name,
            settings=index_settings
        )
    
    def _init_llm(self):
        """
        Initialize Bedrock LLM for response generation.

        Stores a default-constructed ``BedrockLLM`` as ``self.llm``. Other
        methods of this class use it both for query embeddings
        (``get_embedding``) and for answer generation (``generate``).
        """
        self.llm = BedrockLLM()

In [None]:
    def _extract_entities_relations(self, text: str) -> Dict[str, Any]:
        """
        Extract entities and relations from text using SpaCy.
        
        Args:
            text: Input text to process
            
        Returns:
            Dictionary containing extracted entities and relations
        """
        doc = self.nlp(text)
        
        # Extract entities
        entities = []
        for ent in doc.ents:
            if ent.label_ in ["PERSON", "ORG", "GPE", "DATE", "EVENT"]:
                entities.append({
                    "text": ent.text,
                    "label": ent.label_,
                    "start": ent.start_char,
                    "end": ent.end_char
                })
        
        # Extract relations
        relations = []
        for token in doc:
            if token.dep_ == "ROOT":
                for child in token.children:
                    if child.dep_ in ["nsubj", "dobj"]:
                        relations.append({
                            "subject": child.text,
                            "predicate": token.text,
                            "object": next((c.text for c in token.children 
                                          if c.dep_ in ["dobj", "pobj"]), None)
                        })
        
        return {
            "entities": entities,
            "relations": relations
        }

In [None]:
    def _store_graph_data(self, doc_id: str, graph_data: Dict[str, Any]):
        """
        Store extracted entities and relations in Neptune.
        
        Args:
            doc_id: Document identifier
            graph_data: Dictionary containing entities and relations
        """
        # Create document vertex
        doc_vertex_id = self.graph.add_vertex(
            label="Document",
            properties={"id": doc_id},
            id=doc_id
        )
        
        # Add entities
        entity_ids = {}
        for entity in graph_data["entities"]:
            # Create unique ID for entity
            entity_id = f"{entity['text']}_{entity['label']}"
            
            # Add entity vertex if it doesn't exist
            if entity_id not in entity_ids:
                entity_ids[entity_id] = self.graph.add_vertex(
                    label=entity["label"],
                    properties={
                        "text": entity["text"],
                        "label": entity["label"]
                    },
                    id=entity_id
                )
            
            # Link entity to document
            self.graph.add_edge(
                from_id=doc_vertex_id,
                to_id=entity_ids[entity_id],
                label="CONTAINS",
                properties={
                    "start": entity["start"],
                    "end": entity["end"]
                }
            )
        
        # Add relations
        for relation in graph_data["relations"]:
            if relation["object"]:
                # Create relation edge between entities
                subject_matches = self.graph.get_vertices(
                    properties={"text": relation["subject"]}
                )
                object_matches = self.graph.get_vertices(
                    properties={"text": relation["object"]}
                )
                
                if subject_matches and object_matches:
                    self.graph.add_edge(
                        from_id=subject_matches[0]["id"],
                        to_id=object_matches[0]["id"],
                        label=relation["predicate"].upper(),
                        properties={
                            "document": doc_id,
                            "confidence": self.confidence_threshold
                        }
                    )
    
    def _vector_search(self, query: str) -> List[Dict[str, Any]]:
        """
        Perform vector similarity search using OpenSearch.
        
        Args:
            query: Search query
            
        Returns:
            List of retrieved documents with scores
        """
        # Get query embedding from Bedrock
        query_embedding = self.llm.get_embedding(query)
        
        # Perform k-NN search
        search_query = {
            "size": self.k_vector,
            "query": {
                "knn": {
                    "vector": {
                        "vector": query_embedding,
                        "k": self.k_vector
                    }
                }
            }
        }
        
        results = self.opensearch.search(
            index=self.index_name,
            body=search_query
        )
        
        # Format results
        vector_results = []
        for hit in results["hits"]["hits"]:
            vector_results.append({
                "id": hit["_id"],
                "score": hit["_score"],
                "content": hit["_source"]["content"],
                "source": "vector"
            })
        
        return vector_results
    
    def _hybrid_search(self, query: str) -> List[Dict[str, Any]]:
        """
        Perform hybrid search combining graph and vector retrieval.
        
        Args:
            query: Search query
            
        Returns:
            List of retrieved documents with scores
        """
        # Extract entities from query
        query_graph = self._extract_entities_relations(query)
        
        # Get graph-based results
        graph_results = []
        for entity in query_graph["entities"]:
            # Find matching entity vertices
            matches = self.graph.get_vertices(
                properties={"text": entity["text"]}
            )
            
            for match in matches:
                # Get connected documents
                docs = self.graph.get_neighbors(
                    vertex_id=match["id"],
                    direction="in",
                    edge_label="CONTAINS",
                    limit=self.k_graph
                )
                
                for doc in docs:
                    graph_results.append({
                        "id": doc["id"],
                        "score": 1.0,  # TODO: Implement graph scoring
                        "source": "graph"
                    })
        
        # Get vector-based results
        vector_results = self._vector_search(query)
        
        # Combine results
        combined_results = self._merge_results(graph_results, vector_results)
        
        return combined_results

In [None]:
    def _merge_results(
        self,
        graph_results: List[Dict[str, Any]],
        vector_results: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Merge graph and vector search results using weighted scoring.
        
        Args:
            graph_results: Results from graph search
            vector_results: Results from vector search
            
        Returns:
            Combined and re-ranked results
        """
        # Combine all results
        all_results = {}
        
        # Add graph results
        for result in graph_results:
            if result["id"] not in all_results:
                all_results[result["id"]] = {
                    "id": result["id"],
                    "graph_score": result["score"],
                    "vector_score": 0.0
                }
        
        # Add vector results
        for result in vector_results:
            if result["id"] not in all_results:
                all_results[result["id"]] = {
                    "id": result["id"],
                    "graph_score": 0.0,
                    "vector_score": result["score"]
                }
            else:
                all_results[result["id"]]["vector_score"] = result["score"]
        
        # Calculate combined scores
        results = []
        for doc_id, scores in all_results.items():
            combined_score = (
                self.alpha * scores["graph_score"] +
                (1 - self.alpha) * scores["vector_score"]
            )
            results.append({
                "id": doc_id,
                "score": combined_score,
                "graph_score": scores["graph_score"],
                "vector_score": scores["vector_score"]
            })
        
        # Sort by combined score
        results.sort(key=lambda x: x["score"], reverse=True)
        
        return results[:self.k_vector]
    
    def query(self, query: str) -> Dict[str, Any]:
        """
        Process a query using graph-augmented retrieval.
        
        Args:
            query: User query
            
        Returns:
            Dictionary containing response and context
        """
        # Get relevant documents
        results = self._hybrid_search(query)
        
        # Extract graph context
        graph_context = []
        for result in results:
            # Get entities and relations for document
            doc_entities = self.graph.get_neighbors(
                vertex_id=result["id"],
                direction="out",
                edge_label="CONTAINS"
            )
            
            doc_relations = []
            for entity in doc_entities:
                relations = self.graph.get_edges(
                    properties={"document": result["id"]}
                )
                doc_relations.extend(relations)
            
            graph_context.append({
                "doc_id": result["id"],
                "entities": doc_entities,
                "relations": doc_relations
            })
        
        # Format prompt with graph context
        prompt = self._format_prompt(query, results, graph_context)
        
        # Generate response
        response = self.llm.generate(prompt)
        
        return {
            "response": response,
            "context": results,
            "graph_context": graph_context
        }

In [None]:
    def _format_prompt(
        self,
        query: str,
        results: List[Dict[str, Any]],
        graph_context: List[Dict[str, Any]]
    ) -> str:
        """
        Format prompt with retrieved context and graph information.
        
        Args:
            query: Original query
            results: Retrieved documents
            graph_context: Graph relationships
            
        Returns:
            Formatted prompt string
        """
        # Format document context
        doc_context = "\n\n".join(r["content"] for r in results)
        
        # Format graph context
        graph_sections = []
        for ctx in graph_context:
            # Format entities
            entities = [f"{e['text']} ({e['label']})" for e in ctx["entities"]]
            
            # Format relations
            relations = [
                f"{r['from']} {r['label']} {r['to']}"
                for r in ctx["relations"]
            ]
            
            section = f"Document {ctx['doc_id']}:\n"
            section += "Entities: " + ", ".join(entities) + "\n"
            section += "Relations: " + ", ".join(relations)
            graph_sections.append(section)
        
        graph_text = "\n\n".join(graph_sections)
        
        prompt = f"""Use the following information to answer the question.

Document Context:
{doc_context}

Graph Context:
{graph_text}

Question: {query}

Answer:"""
        
        return prompt