# GraphRAG-based Question Answering

This notebook demonstrates using the langchain-graphrag library to implement a knowledge graph-based RAG system.

The approach involves:
1. Text extraction and splitting into units
2. Graph generation using entity and relationship extraction
3. Graph community detection
4. Question answering using either local or global search strategies

In [1]:
import os
from pathlib import Path
from typing import cast
from dotenv import load_dotenv
import pandas as pd

from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.cache import SQLiteCache
from langchain_text_splitters import TokenTextSplitter

from langchain_graphrag.indexing import SimpleIndexer, TextUnitExtractor
from langchain_graphrag.indexing.artifacts_generation import (
    CommunitiesReportsArtifactsGenerator,
    EntitiesArtifactsGenerator, 
    RelationshipsArtifactsGenerator,
    TextUnitsArtifactsGenerator
)
from langchain_graphrag.indexing.graph_clustering import HierarchicalLeidenCommunityDetector
from langchain_graphrag.indexing.graph_generation import (
    EntityRelationshipExtractor,
    EntityRelationshipDescriptionSummarizer,
    GraphGenerator,
    GraphsMerger
)
from langchain_graphrag.indexing.report_generation import (
    CommunityReportGenerator,
    CommunityReportWriter
)
from langchain_graphrag.types.graphs.community import CommunityLevel
from langchain_graphrag.utils import TiktokenCounter

# Load environment variables (python-dotenv)
load_dotenv()

# Working directories used by the rest of the notebook:
#   CACHE_DIR        - LLM response cache database
#   VECTOR_STORE_DIR - persisted vector store data
#   ARTIFACTS_DIR    - reserved for indexing artifacts
CACHE_DIR = Path("cache")
VECTOR_STORE_DIR = Path("vector_stores")
ARTIFACTS_DIR = Path("artifacts")

# Create each directory up front; already-existing directories are left alone.
for directory in (CACHE_DIR, VECTOR_STORE_DIR, ARTIFACTS_DIR):
    directory.mkdir(parents=True, exist_ok=True)

## Configure Environment and Models

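Set up the chat models, the embedding model, the entity vector store, and the text splitter. The OpenAI API key is read from the `OPENAI_API_KEY` environment variable.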

In [2]:
# One SQLite-backed response cache shared by both chat models.
# Both models point at the same database file, so a single cache object
# avoids opening the file twice.
llm_cache = SQLiteCache(str(CACHE_DIR / "openai_cache.db"))

# Create the LLMs
# er_llm: passed to the entity/relationship extractor (and reused by later cells).
er_llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0.0,
    api_key=os.environ["OPENAI_API_KEY"],  # KeyError here means the key is not set
    cache=llm_cache,
)

# es_llm: passed to the entity/relationship description summarizer.
es_llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0.0,
    api_key=os.environ["OPENAI_API_KEY"],
    cache=llm_cache,
)

# Create embeddings
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key=os.environ["OPENAI_API_KEY"],
)

# Create vector store for entities, persisted under VECTOR_STORE_DIR
entities_vector_store = Chroma(
    collection_name="sec-10q-entities",
    persist_directory=str(VECTOR_STORE_DIR),
    embedding_function=embeddings,
)

# Setup text splitter and extractor:
# token-based chunks of 512 with an overlap of 24 between consecutive chunks.
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
text_unit_extractor = TextUnitExtractor(text_splitter=text_splitter)

## Initialize Graph Components

Set up the components needed for graph generation and processing.

In [3]:
# Community detector (does not depend on any LLM-backed component)
community_detector = HierarchicalLeidenCommunityDetector()

# LLM-backed pieces: the description summarizer uses es_llm,
# the entity/relationship extractor uses er_llm.
entity_summarizer = EntityRelationshipDescriptionSummarizer.build_default(
    llm=es_llm,
)
entity_extractor = EntityRelationshipExtractor.build_default(
    llm=er_llm,
)

# Graph generator: combines the extractor, a graphs merger and the summarizer
graph_generator = GraphGenerator(
    graphs_merger=GraphsMerger(),
    er_description_summarizer=entity_summarizer,
    er_extractor=entity_extractor,
)

## Initialize Artifacts Generators

Set up components for generating various artifacts from the graph.

In [4]:
# Generators that need no configuration
text_units_artifacts_generator = TextUnitsArtifactsGenerator()
relationships_artifacts_generator = RelationshipsArtifactsGenerator()

# Entities generator, backed by the entities vector store
entities_artifacts_generator = EntitiesArtifactsGenerator(
    entities_vector_store=entities_vector_store,
)

# Community reports: an LLM-backed report generator paired with a report writer
report_writer = CommunityReportWriter()
report_generator = CommunityReportGenerator.build_default(llm=er_llm)

communities_report_artifacts_generator = CommunitiesReportsArtifactsGenerator(
    report_writer=report_writer,
    report_generator=report_generator,
)

## Load and Process Documents

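Load the input PDF documents. Splitting the text into units is not done here; it happens during indexing, through the `text_unit_extractor` configured above.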

In [None]:
# Load the source PDF documents
from langchain_community.document_loaders.pdf import PyPDFLoader

documents = []
docs_path = Path("../../data/sec-10-q/docs")

# Directory listings come back in arbitrary order, so sort them to load the
# documents in the same order on every run. The extension is matched
# case-insensitively so files such as "REPORT.PDF" are not silently skipped.
pdf_paths = sorted(
    path for path in docs_path.iterdir() if path.name.lower().endswith(".pdf")
)

# Load PDF documents
for file_path in pdf_paths:
    filename = file_path.name
    try:
        docs = PyPDFLoader(str(file_path)).load()
    except Exception as e:
        # Deliberate best-effort: one unreadable PDF must not abort the whole
        # load, so report the failure and continue with the next file.
        print(f"Error processing {filename}: {str(e)}")
    else:
        documents.extend(docs)
        print(f"Processed: {filename}")

# Summary, so a run that loaded nothing is visible before indexing starts
print(f"Loaded {len(documents)} document(s) from {len(pdf_paths)} PDF file(s)")

## Create Indexer and Generate Artifacts

Initialize the indexer and process the documents to generate all required artifacts.

In [None]:
# Assemble the indexer from the components configured in the cells above
indexer = SimpleIndexer(
    # text splitting, graph construction and community detection
    text_unit_extractor=text_unit_extractor,
    graph_generator=graph_generator,
    community_detector=community_detector,
    # artifact generators
    text_units_artifacts_generator=text_units_artifacts_generator,
    entities_artifacts_generator=entities_artifacts_generator,
    relationships_artifacts_generator=relationships_artifacts_generator,
    communities_report_artifacts_generator=communities_report_artifacts_generator,
)

# Run indexing over the loaded documents; the result is consumed by the
# search cells that follow.
artifacts = indexer.run(documents)

## Local Search Example

Set up the local search capability, which is intended for answering specific questions. The cell below builds the search chain; it does not run a query itself.

In [None]:
from langchain_graphrag.query.local_search import (
    LocalSearch,
    LocalSearchPromptBuilder,
    LocalSearchRetriever,
)
from langchain_graphrag.query.local_search.context_builders import ContextBuilder
from langchain_graphrag.query.local_search.context_selectors import ContextSelector

# Context builder, given a tiktoken-based token counter
context_builder = ContextBuilder.build_default(token_counter=TiktokenCounter())

# Context selector: top 10 entities from the entities vector store at
# community level 2 (the cast only informs the type checker).
context_selector = ContextSelector.build_default(
    community_level=cast(CommunityLevel, 2),
    entities_top_k=10,
    entities_vector_store=entities_vector_store,
)

# Retriever combining the selector and builder over the indexing artifacts
retriever = LocalSearchRetriever(
    artifacts=artifacts,
    context_builder=context_builder,
    context_selector=context_selector,
)

# Local search using er_llm, with references enabled in the prompt builder
local_search = LocalSearch(
    retriever=retriever,
    llm=er_llm,
    prompt_builder=LocalSearchPromptBuilder(show_references=True),
)

# Calling the LocalSearch object yields the search chain
search_chain = local_search()

## Global Search Example

Demonstrate using the global search capability for broader questions about the document.

In [None]:
from langchain_graphrag.query.global_search import GlobalSearch
from langchain_graphrag.query.global_search.community_weight_calculator import (
    CommunityWeightCalculator,
)
from langchain_graphrag.query.global_search.key_points_aggregator import (
    KeyPointsAggregator,
    KeyPointsAggregatorPromptBuilder,
    KeyPointsContextBuilder,
)
from langchain_graphrag.query.global_search.key_points_generator import (
    CommunityReportContextBuilder,
    KeyPointsGenerator,
    KeyPointsGeneratorPromptBuilder,
)

# Context builder over the community reports in the indexing artifacts,
# at community level 2 (the cast only informs the type checker).
report_context_builder = CommunityReportContextBuilder(
    artifacts=artifacts,
    community_level=cast(CommunityLevel, 2),
    token_counter=TiktokenCounter(),
    weight_calculator=CommunityWeightCalculator(),
)

# Key-points aggregator, with its own key-points context builder
kp_aggregator = KeyPointsAggregator(
    context_builder=KeyPointsContextBuilder(token_counter=TiktokenCounter()),
    prompt_builder=KeyPointsAggregatorPromptBuilder(show_references=True),
    llm=er_llm,
)

# Key-points generator, fed by the community report context builder
kp_generator = KeyPointsGenerator(
    context_builder=report_context_builder,
    prompt_builder=KeyPointsGeneratorPromptBuilder(show_references=True),
    llm=er_llm,
)

# Global search combining the generator and the aggregator
global_search = GlobalSearch(
    kp_aggregator=kp_aggregator,
    kp_generator=kp_generator,
)