# Microsoft's GraphRAG demo
A minimal end-to-end GraphRAG example:
1) Prepare a workspace and minimal settings.yaml
2) Build the index (same as `graphrag index --root ...`)
3) Run queries:
   - Global Search (dataset-wide, uses community reports)
   - Local Search (entity-centric; mixes knowledge-graph (KG) facts with text passages)

In [None]:
import os
import textwrap
from pathlib import Path

import pandas as pd
import tiktoken

# --- GraphRAG imports (API + utils) ---
import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.base import VectorStoreDocument
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey

# Local search (structured search) components
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
    read_indexer_covariates,
)

In [None]:
# =============================================================================
# 1) WORKSPACE PREPARATION
# =============================================================================

# Community level used by both Global and Local search. It controls how detailed
# the communities in the KG index are; the indexed outputs must contain it.
LEVEL: int = 2

# Every GraphRAG folder lives under a single workspace root.
ROOT = Path("./graphrag_ws").resolve()
INPUT, OUTPUT, CACHE, LANCEDB = (
    ROOT / sub for sub in ("input", "output", "cache", "lancedb")
)
LANCEDB_URI = LANCEDB.as_posix()

# Create the root (with any missing parents), then each sub-folder.
# Folders that already exist are left untouched.
ROOT.mkdir(parents=True, exist_ok=True)
for _folder in (INPUT, OUTPUT, CACHE, LANCEDB):
    _folder.mkdir(exist_ok=True)
del _folder

In [None]:
# Seed ./input with a tiny demo document. Any other .txt files dropped into
# ./input can be used instead.
sample = "\n".join(
    [
        'Charles Dickens wrote "A Christmas Carol". Scrooge is visited by three ghosts.',
        "Bob Cratchit works for Scrooge. Tiny Tim is Bob Cratchit's son.",
    ]
)
(INPUT / "demo.txt").write_text(sample, encoding="utf-8")

In [None]:
# Minimal settings.yaml (OpenAI). For Azure OpenAI, switch model types & fields accordingly.
# Config docs: https://microsoft.github.io/graphrag/config/yaml/
#
# Notes on the generated text:
# - ``${{GRAPHRAG_API_KEY}}`` in this f-string renders as ``${GRAPHRAG_API_KEY}``,
#   a placeholder that is filled in from the environment when the config is loaded.
# - The LanceDB path is emitted as a single-quoted YAML scalar ('' escapes a quote),
#   so spaces, '#', ': ' or quotes in the workspace path cannot break YAML parsing.
# - '$' in the path is doubled because the loaded text is also scanned for ${VAR}
#   placeholders (NOTE(review): assumes string.Template-style substitution, where
#   '$$' denotes a literal '$' — confirm against graphrag's load_config).
_lancedb_uri_yaml = "'" + LANCEDB_URI.replace("$", "$$").replace("'", "''") + "'"

settings = f"""
models:
  default_chat_model:
    api_key: ${{GRAPHRAG_API_KEY}}
    type: openai_chat
    model: gpt-4o
    model_supports_json: true
  default_embedding_model:
    api_key: ${{GRAPHRAG_API_KEY}}
    type: openai_embedding
    model: text-embedding-3-large

input:
  type: file
  base_dir: input
  file_type: text

chunks:
  size: 1200
  overlap: 150

output:
  type: file
  base_dir: output

cache:
  type: file
  base_dir: cache

vector_store:
  default_vector_store:
    type: lancedb
    db_uri: {_lancedb_uri_yaml}
    container_name: default
"""
(ROOT / "settings.yaml").write_text(settings.strip(), encoding="utf-8")

In [None]:
# =============================================================================
# 2) INDEX BUILD (equivalent to `graphrag index --root ...`)
# =============================================================================
async def build_index():
    """
    Run every GraphRAG indexing workflow for the workspace at ``ROOT``.

    The configuration is read from ``ROOT`` via ``load_config``. For each
    workflow result, a ``[INDEX] <name>: OK|ERROR`` line is printed, followed
    by one indented line per reported error.

    Returns
    -------
    cfg : GraphRAG config object
        The loaded configuration, so the caller can reuse it for querying.
    """
    config = load_config(ROOT)
    for result in await api.build_index(config=config):
        failures = result.errors
        status = "ERROR" if failures else "OK"
        print(f"[INDEX] {result.workflow}: {status}")
        # ``errors`` may be empty/None on success; iterate nothing in that case.
        for err in failures or ():
            print("   ->", err)
    return config

In [None]:
# =============================================================================
# Helper: ensure entity description embeddings exist in LanceDB
# =============================================================================
def _ensure_entity_description_embeddings(
    entities_df: pd.DataFrame,
    store: LanceDBVectorStore,
    embedder,
) -> None:
    """
    Populate (overwrite) a LanceDB collection with one vector per entity description.

    Why this is needed:
    - LocalSearch can use a vector store of *entity descriptions* to retrieve
      entities by semantic similarity (e.g., match a query to "Scrooge").
    - If ``entities_df`` already carries a ``description_embedding`` column, those
      vectors are reused. Otherwise each text is embedded with ``embedder``.

    Parameters
    ----------
    entities_df : pd.DataFrame
        DataFrame produced by the indexer (entities.parquet). Column names are
        matched case-insensitively: ``id`` (required), ``title``, ``description``
        and the optional ``description_embedding``.
    store : LanceDBVectorStore
        A connected LanceDB vector store pointing to the collection we want to (over)write.
    embedder : Embedding model (from ModelManager)
        Must have an ``embed(text: str) -> list[float]`` method.

    Notes
    -----
    - The embedded text is the description, falling back to the title when the
      description is missing or blank. Entities with neither are skipped.
    - Missing values (None / NaN / pd.NA) count as empty. They never become the
      literal text "nan", and they are never stored as a vector.
    - If no entity yields a document, the store is left untouched.
    - This function overwrites the collection contents for a clean demo.
      In production, prefer upserts/merges and versioning.
    """
    # Resolve column names case-insensitively so minor schema changes don't break us.
    cols = {c.lower(): c for c in entities_df.columns}
    col_id = cols.get("id", "id")
    col_title = cols.get("title", "title")
    col_desc = cols.get("description", "description")
    col_desc_emb = cols.get("description_embedding")  # may not exist

    def _as_text(value) -> str:
        # None, NaN and pd.NA all mean "no text" (str(nan) would be "nan").
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    def _as_vector(value):
        # A usable precomputed embedding is a non-empty sequence of numbers.
        # None/NaN (or any other scalar) means "not precomputed".
        if value is None or pd.api.types.is_scalar(value) or len(value) == 0:
            return None
        return [float(x) for x in value]

    docs = []
    for row in entities_df.to_dict("records"):
        title = _as_text(row.get(col_title))
        desc = _as_text(row.get(col_desc))
        # A whitespace-only description should still fall back to the title.
        text = desc.strip() or title.strip()
        if not text:
            continue

        vec = _as_vector(row.get(col_desc_emb)) if col_desc_emb else None
        if vec is None:
            # Synchronous embedding call (this helper is not async).
            vec = embedder.embed(text)

        docs.append(
            VectorStoreDocument(
                id=str(row[col_id]),
                text=text,
                vector=vec,
                attributes={"title": title},
            )
        )

    if docs:
        # Create or overwrite the collection with these documents.
        store.load_documents(documents=docs, overwrite=True)

In [None]:
# =============================================================================
# 3) QUERIES: Global + Local
# =============================================================================
def _local_search_model_config(model_cfg, api_key: str, role: str, allowed: tuple[str, ...]):
    """
    Build the ``(ModelType, LanguageModelConfig)`` pair used by Local Search for one model.

    Parameters
    ----------
    model_cfg : model entry from ``cfg.models`` (e.g. ``cfg.models["default_chat_model"]``)
    api_key : str
        API key placed in the new config.
    role : str
        Only used in the error message ("chat" / "embedding").
    allowed : tuple[str, ...]
        Lower-case model type values this demo knows how to configure.

    Raises
    ------
    ValueError
        If the configured model type is not in ``allowed``.
    """
    # ``model_cfg.type`` may be a ``ModelType`` member or a plain string. For a
    # (str, Enum) member, ``str()`` gives "ModelType.OpenAIChat" rather than
    # "openai_chat", so compare on ``.value`` when it is present.
    type_value = str(getattr(model_cfg.type, "value", model_cfg.type)).lower()
    if type_value not in allowed:
        raise ValueError(f"Unsupported {role} model type: {type_value}")
    model_type = ModelType(type_value)

    fields = {
        "api_key": api_key,
        "type": model_type,
        "model": model_cfg.model,
        "max_retries": 20,
    }
    if type_value.startswith("azure_"):
        # For Azure, these fields must be set on the settings.yaml model entry.
        fields.update(
            api_base=model_cfg.api_base,
            api_version=model_cfg.api_version,
            deployment_name=model_cfg.deployment_name,
        )
    return model_type, LanguageModelConfig(**fields)


def _entity_collection_is_current(store, entities_df: pd.DataFrame) -> bool:
    """
    Return True when the store's LanceDB table already holds a row for every entity id.

    A table left over from an earlier run may carry ids that are absent from the
    current entities.parquet (NOTE(review): graphrag appears to assign fresh
    entity ids on each index build — confirm). Local Search below keys the vector
    store by entity ID, so stale ids would silently produce no entity matches.
    """
    table = getattr(store, "document_collection", None)
    if table is None:
        return False
    stored_ids = set(table.to_pandas()["id"].astype(str))
    return set(entities_df["id"].astype(str)) <= stored_ids


async def run_queries(cfg) -> None:
    """
    Run a Global Search and a Local Search against the built index and print the answers.

    Demonstrates:
    - Global Search over community reports
    - Local Search using a mixed context (KG + text) and entity description embeddings

    Parameters
    ----------
    cfg : GraphRAG config object
        The config returned by ``build_index``. The Local Search chat and
        embedding models are taken from ``cfg.models["default_chat_model"]`` and
        ``cfg.models["default_embedding_model"]``, so they always match settings.yaml.

    Raises
    ------
    KeyError
        If ``GRAPHRAG_API_KEY`` is not set in the environment.
    ValueError
        If the configured chat or embedding model type is not supported here.
    RuntimeError
        If no entity description embeddings could be stored in LanceDB.
    """
    # --- Load index artifacts (DataFrames) produced by the index build ---
    entities_df = pd.read_parquet(OUTPUT / "entities.parquet")
    communities_df = pd.read_parquet(OUTPUT / "communities.parquet")
    community_reports_df = pd.read_parquet(OUTPUT / "community_reports.parquet")
    relationship_df = pd.read_parquet(OUTPUT / "relationships.parquet")
    text_unit_df = pd.read_parquet(OUTPUT / "text_units.parquet")

    # Optional: the index build may not produce covariates.parquet.
    try:
        covariate_df = pd.read_parquet(OUTPUT / "covariates.parquet")
    except FileNotFoundError:
        covariate_df = None

    # ---------------------------------------------------------------------
    # 3a) GLOBAL SEARCH (dataset-wide, uses community reports)
    # ---------------------------------------------------------------------
    response_glob, _context_glob = await api.global_search(
        config=cfg,
        entities=entities_df,
        communities=communities_df,
        community_reports=community_reports_df,
        community_level=LEVEL,
        dynamic_community_selection=False,
        response_type="Multiple Paragraphs",
        query="What are the main topics covered in the text?",
    )
    print("\n=== GLOBAL SEARCH ===")
    print(response_glob)

    # ---------------------------------------------------------------------
    # 3b) LOCAL SEARCH (entity-centric; mixes KG + passages)
    #
    # Pipeline overview:
    # 1) Convert indexer DataFrames into typed adapters (expected by LocalSearch)
    # 2) Prepare an entity description embedding store (LanceDB)
    # 3) Create a MixedContext builder that combines:
    #       - community reports (summaries)
    #       - text units (passages)
    #       - entities and relationships (KG)
    #       - optional covariates (claims, etc.)
    #       - entity description embeddings (for entity recall)
    # 4) Run a question against LocalSearch and print the response
    # ---------------------------------------------------------------------
    entities = read_indexer_entities(entities_df, communities_df, LEVEL)
    relationships = read_indexer_relationships(relationship_df)
    reports = read_indexer_reports(community_reports_df, communities_df, LEVEL)
    text_units = read_indexer_text_units(text_unit_df)
    covariates = (
        {"claims": read_indexer_covariates(covariate_df)} if covariate_df is not None else None
    )

    # --- Models (derived from the loaded config, not hard-coded) ---
    api_key = os.environ["GRAPHRAG_API_KEY"]
    chat_settings = cfg.models["default_chat_model"]
    embed_settings = cfg.models["default_embedding_model"]

    chat_type, chat_config = _local_search_model_config(
        chat_settings, api_key, "chat", ("openai_chat", "azure_openai_chat")
    )
    embed_type, embed_config = _local_search_model_config(
        embed_settings, api_key, "embedding", ("openai_embedding", "azure_openai_embedding")
    )

    mm = ModelManager()
    text_embedder = mm.get_or_create_embedding_model(
        name="local_search_embedding",
        model_type=embed_type,
        config=embed_config,
    )
    chat_model = mm.get_or_create_chat_model(
        name="local_search_chat",
        model_type=chat_type,
        config=chat_config,
    )

    # Token encoder for prompt budgeting / truncation. tiktoken raises KeyError
    # for model names it does not know; fall back to a generic encoding then.
    try:
        token_encoder = tiktoken.encoding_for_model(chat_settings.model or "")
    except KeyError:
        token_encoder = tiktoken.get_encoding("cl100k_base")

    # --- Entity description embeddings in LanceDB ---
    # NOTE(review): this is the demo's own table name; the indexer may store its
    # embeddings under a different name (e.g. derived from container_name).
    description_embedding_store = LanceDBVectorStore(
        collection_name="entity_description_embeddings"
    )
    description_embedding_store.connect(db_uri=LANCEDB_URI)

    # (Re)build when the table is missing OR does not cover the current entity ids.
    if not _entity_collection_is_current(description_embedding_store, entities_df):
        print("[INFO] Creating entity description embedding collection in LanceDB…")
        _ensure_entity_description_embeddings(
            entities_df=entities_df,
            store=description_embedding_store,
            embedder=text_embedder,
        )

    if getattr(description_embedding_store, "document_collection", None) is None:
        raise RuntimeError(
            "No entity description embeddings were stored in LanceDB: "
            "entities.parquet yielded no entity with a description or title."
        )

    # Quick probe: a similarity search should succeed now.
    probe_vec = text_embedder.embed("Scrooge")
    _ = description_embedding_store.similarity_search_by_vector(probe_vec, k=1)
    print("[OK] Entity description embedding collection is queryable.")

    # --- Build mixed context and run the local search ---
    ctx_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=covariates,  # None if not available
        entity_text_embeddings=description_embedding_store,
        # The vectors above are keyed by the internal entity ID.
        # If you indexed by entity title instead, use: EntityVectorStoreKey.TITLE
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )

    local = LocalSearch(
        model=chat_model,
        context_builder=ctx_builder,
        token_encoder=token_encoder,
    )
    print(f"LocalSearch object ready: {local}")

    result_local = await local.search(
        query="Who is Scrooge and what are his key relationships?"
    )

    print("\n=== LOCAL SEARCH ===")
    print(result_local.response)

In [None]:
# =============================================================================
# Entry point
# =============================================================================
async def main():
    """
    Execute the complete demo pipeline.

    First builds the index (./input -> ./output). Then runs both the Global and
    the Local search against it, using the configuration the build returned.
    """
    await run_queries(await build_index())

In [None]:
# Top-level ``await`` works here because Jupyter/IPython run cells inside an
# event loop; in a plain Python script use ``asyncio.run(main())`` instead.
await main()