In [1]:
import os
from pathlib import Path

import pandas as pd

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Directory holding the parquet tables read by this notebook, and the LanceDB
# database (under "lancedb") that the vector stores connect to.
INPUT_DIR = "../output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

# Parquet file stems; each table is read from f"{INPUT_DIR}/{<NAME>}.parquet".
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
# NOTE(review): COVARIATE_TABLE is not referenced by any cell visible here.
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
# Community level passed to read_indexer_entities and read_indexer_reports.
COMMUNITY_LEVEL = 2


# Read the entity and community tables from INPUT_DIR.
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# Connect to the entity-description embeddings in the LanceDB database at
# LANCEDB_URI (a path under INPUT_DIR, not an in-memory store).
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-entity-description"
    ),
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

# Second store over the same database, for community full-content embeddings.
full_content_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-community-full_content"
    )
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
# A bare expression is only rendered when it is the last statement of a cell;
# this one is mid-cell, so display() is needed for the preview to show up.
display(entity_df.head())

# Read the relationships table and convert it for the query engine.
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
# Mid-cell expression: without display() this preview is silently discarded.
display(relationship_df.head())

# Read the text-unit table and convert it for the query engine.
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
# Last expression of the cell, so it is rendered as the cell output.
text_unit_df.head()

Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description',
       'text_unit_ids', 'frequency', 'degree', 'x', 'y'],
      dtype='object')
Entity count: 779
Relationship count: 923
Text unit records: 24


Unnamed: 0,id,human_readable_id,text,n_tokens,document_ids,entity_ids,relationship_ids,covariate_ids
0,ff10b7b5be1317511cc3907906969758360dec82ea361a...,0,## Table Name: REPORTING_DEV.DBT_MAHMUDNABI_gt...,1200,[00b79e7cc5f646a669965928ff21d560f53f2d2ca46d7...,"[ae58bb29-eacc-4c31-948f-531417039a47, 1676e1b...","[434c4f3f-1436-42cb-8c7f-047e4d457948, 73ee415...",[]
1,51bca52aa57117b2edc8766bb29693c83e090fe7c5c7de...,1,– Final RSAC calculation in USD. \n\n## Logi...,493,[00b79e7cc5f646a669965928ff21d560f53f2d2ca46d7...,"[1676e1b7-6196-4da9-b2d4-750e993f07a1, 21bf732...","[204fcfcf-e9c7-4b00-90e3-e6eabd2d12a3, 1411182...",[]
2,6a36014cd02265a1ee85377d17176e124289e1382c9d83...,2,## Table Name: REPORTING_DEV.DBT_MAHMUDNABI_gt...,1200,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[5f99b9ca-4796-4f9b-92b4-55c7debba513, d6c03dd...","[d95fb42f-bc31-4a08-899f-12f3a06eda01, afd2492...",[]
3,7809e630d7f90df516b6d6acf64702bbe19eada9ffb6dc...,3,ABI_snapshots.sfdc_opportunity_line_item_analy...,1200,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[5e3a9cb1-80fe-4704-95a5-364f6fa2cff9, 6debdb2...","[8a55a3b1-65bf-496a-b14a-820262b3925c, 1293da7...",[]
4,c4b3a498009191e3694cfed96a8f92b256b701adafd452...,4,_DEV.DBT_MAHMUDNABI_snapshots.sfdc_opportunity...,203,[065e17268d0619f2b6281ba451e91b863f08f8fe4e64a...,"[bbdb67ce-0423-4161-9517-ba7331cd4341, f4c024a...","[4601022c-1141-48df-909d-0080e3f037de, 87e55ca...",[]


In [2]:
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# Credentials come from the environment so no secret is stored in the notebook.
# Either lookup raises KeyError when the variable is not set.
api_key = os.environ["GEMINI_API_KEY"]
graphrag_api_key = os.environ["GRAPHRAG_API_KEY"]

# --- Chat model -------------------------------------------------------------
chat_config = LanguageModelConfig(
    type=ModelType.Chat,
    model_provider="gemini",
    model="gemini-2.5-flash-lite",
    auth_type="api_key",
    api_key=api_key,
    model_supports_json=True,
    temperature=0.0,
    max_retries=20,
    # Disabled endpoint override kept for reference:
    # api_base="http://localhost:11434/v1",
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.Chat,
    config=chat_config,
)

# --- Embedding model --------------------------------------------------------
# Served through the OpenAI-style endpoint configured in api_base.
embedding_config = LanguageModelConfig(
    type=ModelType.Embedding,
    model_provider="openai",
    model="all-minilm:latest",
    api_base="http://localhost:11434/v1",
    api_key=graphrag_api_key,
    max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.Embedding,
    config=embedding_config,
)

# The tokenizer is derived from the chat model configuration.
tokenizer = get_tokenizer(chat_config)

In [3]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Read the community reports parquet table into a DataFrame.

    The file is looked up at ``<input_dir>/<community_report_table>.parquet``.
    This function only reads; it does not compute embeddings or write anything.

    Args:
        input_dir: Directory containing the parquet file.
        community_report_table: File stem of the reports table.

    Returns:
        The DataFrame produced by ``pd.read_parquet`` for that file.
    """
    parquet_file = Path(input_dir, f"{community_report_table}.parquet")
    return pd.read_parquet(parquet_file)


# Load the community reports table from INPUT_DIR.
report_df = read_community_reports(INPUT_DIR)
# Build the report objects from the reports and community tables at
# COMMUNITY_LEVEL.
# NOTE(review): content_embedding_col names the column expected to hold report
# embeddings — confirm how the reader behaves if the parquet lacks it.
reports = read_indexer_reports(
    report_df,
    community_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
# The return value is discarded, so this call is presumably relied on to fill
# in `reports` from full_content_embedding_store in place — TODO confirm.
read_indexer_report_embeddings(reports, full_content_embedding_store)

In [4]:
# Settings handed to the DRIFT context builder below.
drift_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=12_000,
    primer_folds=1,
    drift_k_followups=3,
    n_depth=3,
    n=1,
)

# Wire together the models, the indexed data and the entity-description
# vector store loaded in the earlier cells.
context_builder = DRIFTSearchContextBuilder(
    model=chat_model,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
    tokenizer=tokenizer,
    config=drift_params,
)

search = DRIFTSearch(
    model=chat_model, context_builder=context_builder, tokenizer=tokenizer
)

# No query is issued in this setup cell. It used to call search.search("")
# with an empty string, which cannot produce a meaningful answer (and is
# believed to be rejected by DRIFTSearch — verify against the installed
# graphrag version). The real question is asked in the following cell, which
# is where `resp` gets bound.

In [5]:
resp = await search.search("What is the trend of closed opportunities over the last four quarters? Can you write me a sql query ?")


Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
Reached token limit - reverting to previous context state
No answer found for query: What are the specific stages that are excluded from the `CLOSED_OPPORTUNITIES` CTE in the Opportunity RSAC Analytics Community?
No follow-up actions found for response: {}
 67%|██████▋   | 2/3 [00:02<00:01,  1.33s/it]No answer found for query: Are there any specific product categories that show a different trend in closed opportunities?
No follow-up actions found for response: {}
No follow-up actions for action: What are the specific stages that are excluded from the `CLOSED_OPPORTUNITIES` CTE in the Opportunity RSAC Analytics Community?
No follow-up actions for action: Are there any specific product categories that show a different trend in closed opportunities?
  0%|          | 0/3 [00:00<?, ?it/s]Reached token limit - reverting to previous context state
Reached token limit - reverting 

In [7]:
# Show the answer text of the search result currently bound to `resp`.
resp.response

'To determine the trend of closed opportunities over the last four quarters, you would need to count the number of opportunities that entered the "Closed Won" stage for each of those quarters. The `Sales Opportunity Stage Tracking Community` and `Opportunity RSAC Analytics Community` are relevant for this analysis. The `OPPORTUNITY_RSAC_QUARTERLY` model appears to be the most direct source for aggregated quarterly data.\n\nHere is a conceptual SQL query to achieve this:\n\n```sql\nWITH QuarterlyClosedOpportunities AS (\n    SELECT\n        STRFTIME(\'%Y-%q\', DATE_ENTERED_STAGE_CLOSED_WON) AS closed_quarter,\n        COUNT(DISTINCT opportunity_id) AS number_of_closed_opportunities -- Assuming an opportunity_id exists\n    FROM\n        your_sales_data_table -- Replace with the actual table containing opportunity stage dates\n    WHERE\n        DATE_ENTERED_STAGE_CLOSED_WON IS NOT NULL\n        AND DATE_ENTERED_STAGE_CLOSED_WON >= DATE(\'now\', \'-4 quarters\') -- Filter for the last fo