In [2]:
import os
from pathlib import Path

import pandas as pd

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Directory holding the GraphRAG indexer output (parquet tables + LanceDB).
INPUT_DIR = "../output"
LANCEDB_URI = INPUT_DIR + "/lancedb"

# Parquet table names under INPUT_DIR, without the ".parquet" suffix.
COMMUNITY_TABLE = "communities"
COMMUNITY_REPORT_TABLE = "community_reports"
COVARIATE_TABLE = "covariates"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
TEXT_UNIT_TABLE = "text_units"

# Community level handed to read_indexer_entities / read_indexer_reports.
COMMUNITY_LEVEL = 2


# Load the entities and communities tables from the indexer output; the
# communities table is needed to select entities at COMMUNITY_LEVEL.
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# Open the on-disk LanceDB database at LANCEDB_URI (presumably written by the
# indexing run) and bind one store per embedding index: entity descriptions
# and community-report full content.
description_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-entity-description"
    ),
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

full_content_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-community-full_content"
    ),
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
# Jupyter only renders a cell's *last* expression, so a bare `.head()` here
# would be discarded; `display` is an IPython builtin inside the kernel.
display(entity_df.head())

# Load the relationships table and adapt it for the DRIFT context builder.
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
# Not the cell's last expression, so render the preview explicitly
# (`display` is an IPython builtin inside the kernel).
display(relationship_df.head())

# Load the text-units table (its preview shows text, entity_ids and
# relationship_ids columns) and adapt it for the DRIFT context builder.
text_unit_path = f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet"
text_unit_df = pd.read_parquet(text_unit_path)
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {text_unit_df.shape[0]}")

# Final expression of the cell, so Jupyter renders this preview.
text_unit_df.head()

Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description',
       'text_unit_ids', 'frequency', 'degree', 'x', 'y'],
      dtype='object')
Entity count: 125
Relationship count: 226
Text unit records: 9


Unnamed: 0,id,human_readable_id,text,n_tokens,document_ids,entity_ids,relationship_ids,covariate_ids
0,2d3026bc93e628d6b4d0071969a23abfff15eb2fbef9dc...,0,## Table Name: int_customer_metrics\n\n## Over...,562,[09aa19a17e6833f9da5e6d737dee4cc058bdd4e7115a9...,"[23c98631-fae8-4444-bd78-31b588bb8765, c412219...","[7fb116e3-9ae7-457f-821c-1d18afca9432, 999e961...",[]
1,604cccc9579ea935009fab92e2679124521f0f79753c1e...,1,## Table Name: stg_orders\n\n## Overview \nTh...,544,[269af84a5666aac1bd4d20e69aacafaa145408e340416...,"[50c7358d-5b61-4ef3-935d-b5a736840f61, eec5774...","[73e60c02-f6e2-4bf3-87c5-0d96c935ebdf, ad93292...",[]
2,279a49ce979aead42c66127c82eb1cde49c26fba1e2026...,2,## Table Name: rpt_product_trends\n\n## Overvi...,483,[280164b5603c6fb16468df43ddcc1b809a7a3e15b6aec...,"[50c9e1fd-e06a-4435-a113-df5e756d0daf, 18961e9...",[55ffe4be-4f7c-4053-a283-74ee4c360138],[]
3,15b8b71df2aa360d584e721a581a580689833ea3115d09...,3,## Table Name: stg_customers\n\n## Overview \...,460,[5e2c8f518f025cedca97fe98b771b0b0355b1d122ae19...,"[fb773319-3550-4144-bd3c-f85171c50d40, 50c7358...","[73e60c02-f6e2-4bf3-87c5-0d96c935ebdf, 9749630...",[]
4,ad8960df886c9a66a1ba02ef31b0ad1718b9f9ffffb75c...,4,## Table Name: int_product_performance\n\n## O...,511,[5fad47491c98209419812f307d47d9c73185c857b35d9...,"[50c9e1fd-e06a-4435-a113-df5e756d0daf, eec5774...","[c3b89705-fc3b-4bad-862e-1f9a3c0158ca, 27f48aa...",[]


In [3]:
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# Gemini credentials come from the environment; never hardcode the key.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    # Fail fast with an actionable message instead of a bare KeyError (or an
    # empty key being passed through to both model configs below).
    raise RuntimeError(
        "GEMINI_API_KEY is not set; export it before starting the kernel."
    )

# Retry budget shared by the chat and embedding models.
MAX_RETRIES = 20

chat_config = LanguageModelConfig(
    api_key=api_key,
    auth_type="api_key",
    model_supports_json=True,
    type=ModelType.Chat,
    model_provider="gemini",
    model="gemini-2.5-flash-lite",
    max_retries=MAX_RETRIES,
    temperature=0.0,
)
# `name` is the ModelManager registry key for this model.
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.Chat,
    config=chat_config,
)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    auth_type="api_key",
    type=ModelType.Embedding,
    model_provider="gemini",
    model="gemini-embedding-001",
    max_retries=MAX_RETRIES,
)
text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.Embedding,
    config=embedding_config,
)

# Tokenizer derived from the chat config; shared by the DRIFT context builder
# and search object constructed later.
tokenizer = get_tokenizer(chat_config)

In [4]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
) -> pd.DataFrame:
    """Read the community-reports parquet table from the indexer output.

    Args:
        input_dir: Directory containing the indexer's parquet outputs.
        community_report_table: Table name, without the ``.parquet`` suffix.

    Returns:
        The community-reports DataFrame exactly as stored on disk. Nothing is
        embedded or written here; full-content embeddings are loaded
        separately (see ``read_indexer_report_embeddings``).
    """
    return pd.read_parquet(Path(input_dir) / f"{community_report_table}.parquet")


# Raw community-report rows from the indexer output.
report_df = read_community_reports(INPUT_DIR)

# Adapt the rows, together with community_df, into graphrag report objects
# for COMMUNITY_LEVEL.
reports = read_indexer_reports(
    report_df, community_df, COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)

# Attach full-content embeddings from LanceDB. The return value is unused;
# presumably the helper fills them in on `reports` in place — confirm.
read_indexer_report_embeddings(reports, full_content_embedding_store)

In [5]:
# DRIFT search parameters.
# NOTE(review): `temperature`, `max_tokens` and `n` are not fields on every
# graphrag version's DRIFTSearchConfig; if the installed version lacks them
# they may be silently ignored — confirm against its schema.
drift_params = DRIFTSearchConfig(
    n=1,
    n_depth=3,
    primer_folds=1,
    drift_k_followups=3,
    max_tokens=12000,
    temperature=0,
)

# Context builder: combines the graph tables, reports, text units and the
# entity-description vector store into search context.
context_builder = DRIFTSearchContextBuilder(
    config=drift_params,
    model=chat_model,
    text_embedder=text_embedder,
    tokenizer=tokenizer,
    entities=entities,
    entity_text_embeddings=description_embedding_store,
    relationships=relationships,
    reports=reports,
    text_units=text_units,
)

search = DRIFTSearch(
    model=chat_model,
    context_builder=context_builder,
    tokenizer=tokenizer,
)

In [None]:
resp = await search.search("")

  0%|          | 0/3 [00:00<?, ?it/s]Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
 33%|███▎      | 1/3 [00:03<00:07,  3.94s/it]No follow-up actions found for response: {}
 67%|██████▋   | 2/3 [00:04<00:02,  2.07s/it]No answer found for query: What is the date range covered by the data in staging.stg_orders?
No follow-up actions found for response: {}
No follow-up actions for action: What is the schema and data types for the `raw.orders` and `raw.order_items` tables?
No follow-up actions for action: What is the date range covered by the data in staging.stg_orders?
  0%|          | 0/3 [00:00<?, ?it/s]Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
No answer found for query: What is the relationship between `int_customer_metrics` and the

In [12]:
# Answer text produced by the DRIFT search above, rendered as this cell's output.
resp.response

'To identify customers who have placed more than 10 orders, you should query the `CUSTOMER_ORDERS` model. This model aggregates order activity per customer and calculates metrics such as `total_orders`.\n\nHere is the SQL query to retrieve customers with more than 10 orders:\n\n```sql\nSELECT\n    customer_id,\n    total_orders,\n    total_spent,\n    avg_order_value,\n    last_order_date\nFROM\n    CUSTOMER_ORDERS\nWHERE\n    total_orders > 10;\n```\n\nThis query selects all columns from the `CUSTOMER_ORDERS` model and filters the results to include only those records where the `total_orders` metric is greater than 10. The `CUSTOMER_ORDERS` model is derived from the `STAGING.STG_ORDERS` table, which contains granular transaction details [Data: CUSTOMER_ORDERS (record ids), STAGING.STG_ORDERS (record ids)].\n\nAdditionally, the `COALESCE` function is used in models like `INT_CUSTOMER_METRICS` to replace `NULL` values with 0 for metrics such as `total_orders`. This ensures that the `tot