In [1]:
import os
from pathlib import Path

import pandas as pd

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "../output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2


# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-entity-description"
    ),
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

full_content_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-community-full_content"
    )
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
entity_df.head()

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()

Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description',
       'text_unit_ids', 'frequency', 'degree', 'x', 'y'],
      dtype='object')
Entity count: 54
Relationship count: 67
Text unit records: 9


Unnamed: 0,id,human_readable_id,text,n_tokens,document_ids,entity_ids,relationship_ids,covariate_ids
0,95eb3a39a1a78090ce34808c97a555a7e673e848db50e6...,0,## Overview \nThis model aggregates daily ord...,502,[0059107dd9400f66f52d38454f45f7737dd6b24741e63...,,,[]
1,28d87dab888dbffc333d5a43db76429c740346bc58021d...,1,## Overview \nThis model calculates the month...,474,[0ce4475c61656693c5c816b19570f5c3407d97d498acd...,"[8d8d9b6e-5153-4977-9b4c-0a178cf404dd, adb0fbf...","[0070f973-4d00-4525-8cc4-b73efab93f07, f427ba0...",[]
2,8ef0d23ed4892a59e9440eedd3d2679f77b0d283ff5b99...,2,## Overview \nThis model aggregates order act...,554,[2e5f96e339bbe668d209dc36f0598baa8cf79042f4cf9...,,,[]
3,7c518b9a0d3935bd033643f865fb7006583514716ed4a5...,3,## Overview \nThis model cleans and normalize...,536,[3c45e10176b3ab121aace825d2d98a2318b06477ba4dc...,,,[]
4,b488c06b06c8129a942f0ce6868d5bdae68a8f23043276...,4,## Overview \nThis model aggregates sales per...,529,[4e721ac6c35ebc7547fea377d410d93b54fbf38fa0a58...,,,[]


In [2]:
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.tokenizer.get_tokenizer import get_tokenizer

api_key = os.environ["GRAPHRAG_API_KEY"]

chat_config = LanguageModelConfig(
    api_key=api_key,
    auth_type="api_key",
    model_supports_json=True,
    api_base="http://localhost:11434/v1",
    type=ModelType.Chat,
    model_provider="openai",
    model="qwen3:14b",
    max_retries=20,
    temperature=0.0,
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.Chat,
    config=chat_config,
)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.Embedding,
    model_provider="openai",
    model="nomic-embed-text",
    max_retries=20,
    api_base="http://localhost:11434/v1",
)

text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.Embedding,
    config=embedding_config,
)

tokenizer = get_tokenizer(chat_config)

In [3]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
    input_path = Path(input_dir) / f"{community_report_table}.parquet"
    return pd.read_parquet(input_path)


report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
    report_df,
    community_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)

In [4]:
drift_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=12_000,
    primer_folds=1,
    drift_k_followups=3,
    n_depth=3,
    n=1,
)

context_builder = DRIFTSearchContextBuilder(
    model=chat_model,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
    tokenizer=tokenizer,
    config=drift_params,
)

search = DRIFTSearch(
    model=chat_model, context_builder=context_builder, tokenizer=tokenizer
)

In [5]:
resp = await search.search("on what basis the customers are segmented?")

  0%|          | 0/3 [00:00<?, ?it/s]        Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
  0%|          | 0/3 [00:00<?, ?it/s]        Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
  0%|          | 0/3 [00:00<?, ?it/s]        Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
                                             

In [6]:
resp.response

'Customers are segmented based on the following criteria, as outlined in the data:\n\n1. **Geographic Region**:  \n   The `REGION` entity (Data: 30) is explicitly identified as a key dimension for segmentation. Metrics like average lifetime value (AVG_LTV) and average orders (AVG_ORDERS) are aggregated by region and customer segment [Data: Reports (0); Sources (6, 7)].\n\n2. **Behavioral Metrics (Spending Patterns)**:  \n   The **BASE model** classifies customers into tiers (e.g., Platinum, Gold, Silver, Bronze) using **total_spent thresholds** [Data: Reports (0); Sources (6)]. This is further refined by joining cleaned customer data (with active status) to intermediate metrics tables [Data: Sources (6, 7)].\n\n3. **Data Quality and Engagement Status**:  \n   - The **`is_active` flag** (normalized to `TRUE` if NULL) ensures only valid, active customers are included in segmentation [Data: Sources (7); Reports (0)].  \n   - Emails are filtered to retain only records containing an `@` cha