In [6]:
import os
from pathlib import Path

import pandas as pd
import tiktoken

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "./output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
#COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2


# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

full_content_embedding_store = LanceDBVectorStore(
    collection_name="default-community-full_content",
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
entity_df.head()

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()

Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description',
       'text_unit_ids', 'frequency', 'degree', 'x', 'y'],
      dtype='object')
Entity count: 1824
Relationship count: 2641
Text unit records: 243


Unnamed: 0,id,human_readable_id,text,n_tokens,document_ids,entity_ids,relationship_ids,covariate_ids
0,422c343a682e7f78dae36ac94f2bf135ec09bd2784744d...,1,Tell us about your PDF experience.\nMicrosoft ...,1200,[d5835eb381e6b16dd4dd9c9b74c2bb54d06463d791e93...,"[3e730b25-df4d-4390-a891-3cb559ab387a, 09903b4...","[3877ccd6-2ae9-46bc-82de-3acf47ff371e, abb9d1b...",[]
1,7f05149428714a222f2c4fa84b8dc05f8991a7a9266cab...,2,"ises and in the cloud. For more information, s...",1200,[d5835eb381e6b16dd4dd9c9b74c2bb54d06463d791e93...,"[3e730b25-df4d-4390-a891-3cb559ab387a, 09903b4...","[3877ccd6-2ae9-46bc-82de-3acf47ff371e, 6b4b719...",[]
2,406c2d7deeffa71c9bcd8922a425f6548d58eabd180b0d...,3,"IoT Hub, Azure SQL DB Change Data Capture (CD...",1200,[d5835eb381e6b16dd4dd9c9b74c2bb54d06463d791e93...,"[3e730b25-df4d-4390-a891-3cb559ab387a, 09903b4...","[3877ccd6-2ae9-46bc-82de-3acf47ff371e, ba7b4a6...",[]
3,03b9935515464b3ba9fa660b98d2b7bd12a24ebe4f6494...,4,"For detailed instructions, see\nMoving your d...",1200,[d5835eb381e6b16dd4dd9c9b74c2bb54d06463d791e93...,"[3e730b25-df4d-4390-a891-3cb559ab387a, 5bffd67...","[ba7b4a6c-1a26-45b7-a0fa-e3dfd981aa88, 6b4b719...",[]
4,0c2447dd1987bf104efc25070271e6f085d4b9c2945d60...,5,", see Canceling, expiring, and closing.\nCance...",1200,[d5835eb381e6b16dd4dd9c9b74c2bb54d06463d791e93...,"[5bffd675-01c6-4f6d-a247-99333ee64f05, c44a355...","[0ad8a4e0-cedf-41fa-bc40-0a41babad623, f08cbdf...",[]


In [7]:
# feedback why should we pass the type to both config and the chat_model etc?

from graphrag.config.enums import ModelType, AuthType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

api_key = os.getenv("GRAPHRAG_API_KEY")
llm_model = os.getenv("GRAPHRAG_LLM_MODEL")
embedding_model = os.getenv("GRAPHRAG_EMBEDDING_MODEL")

chat_config = LanguageModelConfig(
    api_key=api_key,
    auth_type=AuthType.APIKey, 
    type=ModelType.AzureOpenAIChat,
    model=llm_model,
    deployment_name=llm_model,
    max_retries=20,
    api_base= os.getenv("GRAPHRAG_API_BASE"),
    api_version="2024-02-15-preview"
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.AzureOpenAIChat,
    config=chat_config,
)

token_encoder = tiktoken.encoding_for_model(llm_model)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    auth_type=AuthType.APIKey,
    type=ModelType.AzureOpenAIEmbedding,  # <-- Switch to AzureOpenAIEmbedding
    model=embedding_model,                # <-- This should be your Azure deployment name for embeddings
    deployment_name=embedding_model,      # <-- Same as above
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version="2024-02-15-preview"
)

text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.AzureOpenAIEmbedding,
    config=embedding_config,
)

In [8]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
    input_path = Path(input_dir) / f"{community_report_table}.parquet"
    return pd.read_parquet(input_path)


report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
    report_df,
    community_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)

In [9]:
drift_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=12_000,
    primer_folds=1,
    drift_k_followups=3,
    n_depth=3,
    n=1,
)

context_builder = DRIFTSearchContextBuilder(
    model=chat_model,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
    token_encoder=token_encoder,
    config=drift_params,
)

search = DRIFTSearch(
    model=chat_model, context_builder=context_builder, token_encoder=token_encoder
)

In [10]:


result = await search.search("how do you do shortcuts?")
print(result.response)

                                             

TypeError: AsyncCompletions.create() got an unexpected keyword argument 'model_params'

In [None]:
print(result.context_data)