In [1]:
import os
from pathlib import Path

import pandas as pd

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "../output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2


# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-entity-description"
    ),
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

full_content_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-community-full_content"
    )
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
entity_df.head()

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()

Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description',
       'text_unit_ids', 'frequency', 'degree', 'x', 'y'],
      dtype='object')
Entity count: 2033
Relationship count: 2761
Text unit records: 63


Unnamed: 0,id,human_readable_id,text,n_tokens,document_ids,entity_ids,relationship_ids,covariate_ids
0,ff10b7b5be1317511cc3907906969758360dec82ea361a...,0,## Table Name: REPORTING_DEV.DBT_MAHMUDNABI_gt...,1200,[00b79e7cc5f646a669965928ff21d560f53f2d2ca46d7...,"[b8d3f062-af19-4698-9357-9b434c98a683, a4db786...","[bd1b3d23-b456-42e9-854c-ff052adadd56, bea9bc9...",[]
1,51bca52aa57117b2edc8766bb29693c83e090fe7c5c7de...,1,– Final RSAC calculation in USD. \n\n## Logi...,493,[00b79e7cc5f646a669965928ff21d560f53f2d2ca46d7...,"[a4db786a-ff10-4877-b90a-21b8f24abd35, aba8aef...","[fa75a23e-e96a-418d-89a5-1a6257095279, 1c2c369...",[]
2,6a36014cd02265a1ee85377d17176e124289e1382c9d83...,2,## Table Name: REPORTING_DEV.DBT_MAHMUDNABI_gt...,1200,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[e5852cdb-4b00-4246-9a21-74713e18ec80, ab74df7...","[e58f0353-9bd0-4a30-afe3-f0f11fe4b24b, 37bbe9d...",[]
3,7809e630d7f90df516b6d6acf64702bbe19eada9ffb6dc...,3,ABI_snapshots.sfdc_opportunity_line_item_analy...,1200,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[0f560add-8d09-4dcc-be85-7cbd68a11296, 13b9e57...","[3ea8d334-1a92-4444-addf-1561e6770c30, 9f4e2ec...",[]
4,c4b3a498009191e3694cfed96a8f92b256b701adafd452...,4,_DEV.DBT_MAHMUDNABI_snapshots.sfdc_opportunity...,203,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[acecc78b-aa58-4d04-bfe5-b264ce9b8dc2, 1f716c3...","[fec5fe9c-6e42-4d15-a258-0463541831b1, f13aba4...",[]


In [3]:
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.tokenizer.get_tokenizer import get_tokenizer

api_key = os.environ["GEMINI_API_KEY"]

chat_config = LanguageModelConfig(
    api_key=api_key,
    auth_type="api_key",
    model_supports_json=True,
    # api_base="http://localhost:11434/v1",
    type=ModelType.Chat,
    model_provider="gemini",
    model="gemini-2.5-flash-lite",
    max_retries=20,
    temperature=0.0,
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.Chat,
    config=chat_config,
)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.Embedding,
    model_provider="gemini",
    model="gemini-embedding-001",
    max_retries=20,
    # api_base="http://localhost:11434/v1",
)

text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.Embedding,
    config=embedding_config,
)

tokenizer = get_tokenizer(chat_config)

In [None]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
    input_path = Path(input_dir) / f"{community_report_table}.parquet"
    return pd.read_parquet(input_path)


report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
    report_df,
    community_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)

AttributeError: 'NoneType' object has no attribute 'search'

In [5]:
drift_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=12_000,
    primer_folds=1,
    drift_k_followups=3,
    n_depth=3,
    n=1,
)

context_builder = DRIFTSearchContextBuilder(
    model=chat_model,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
    tokenizer=tokenizer,
    config=drift_params,
)

search = DRIFTSearch(
    model=chat_model, context_builder=context_builder, tokenizer=tokenizer
)

NameError: name 'reports' is not defined

In [None]:
resp = await search.search("")

  0%|          | 0/3 [00:00<?, ?it/s]Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
 33%|███▎      | 1/3 [00:03<00:07,  3.94s/it]No follow-up actions found for response: {}
 67%|██████▋   | 2/3 [00:04<00:02,  2.07s/it]No answer found for query: What is the date range covered by the data in staging.stg_orders?
No follow-up actions found for response: {}
No follow-up actions for action: What is the schema and data types for the `raw.orders` and `raw.order_items` tables?
No follow-up actions for action: What is the date range covered by the data in staging.stg_orders?
  0%|          | 0/3 [00:00<?, ?it/s]Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
No answer found for query: What is the relationship between `int_customer_metrics` and the

In [12]:
resp.response

'To identify customers who have placed more than 10 orders, you should query the `CUSTOMER_ORDERS` model. This model aggregates order activity per customer and calculates metrics such as `total_orders`.\n\nHere is the SQL query to retrieve customers with more than 10 orders:\n\n```sql\nSELECT\n    customer_id,\n    total_orders,\n    total_spent,\n    avg_order_value,\n    last_order_date\nFROM\n    CUSTOMER_ORDERS\nWHERE\n    total_orders > 10;\n```\n\nThis query selects all columns from the `CUSTOMER_ORDERS` model and filters the results to include only those records where the `total_orders` metric is greater than 10. The `CUSTOMER_ORDERS` model is derived from the `STAGING.STG_ORDERS` table, which contains granular transaction details [Data: CUSTOMER_ORDERS (record ids), STAGING.STG_ORDERS (record ids)].\n\nAdditionally, the `COALESCE` function is used in models like `INT_CUSTOMER_METRICS` to replace `NULL` values with 0 for metrics such as `total_orders`. This ensures that the `tot