In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.

In [None]:
import os
from pathlib import Path

import pandas as pd

import graphrag.api as api
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig


In [None]:
# --- Credentials and model names (read from the environment; KeyError if unset) ---
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]

# --- Index output location ---
# The parquet tables this notebook reads live directly in INPUT_DIR;
# the LanceDB vector store sits in its "lancedb" subfolder.
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = INPUT_DIR + "/lancedb"

# --- Parquet table names (file stems inside INPUT_DIR) ---
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
TEXT_UNIT_TABLE = "text_units"

# Forwarded to api.drift_search as its `community_level` argument.
COMMUNITY_LEVEL = 2


In [None]:
# Retry budget shared by the chat and embedding models so the two stay in sync.
MAX_RETRIES = 20

chat_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIChat,
    model=llm_model,
    max_retries=MAX_RETRIES,
)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIEmbedding,
    model=embedding_model,
    max_retries=MAX_RETRIES,
)

# Derive the store location from LANCEDB_URI (set in the settings cell) instead of
# rebuilding the path here, so the two cannot drift apart; resolve() makes it absolute.
# NOTE(review): overwrite=True presumably only matters when embeddings are written;
# confirm it has no effect on this read-only query path.
vector_store_config = VectorStoreConfig(
    type="lancedb",
    db_uri=str(Path(LANCEDB_URI).resolve()),
    container_name="default",
    overwrite=True,
)

# Both search methods point at the same registered chat/embedding model ids.
config = create_graphrag_config(
    {
        "models": {
            "default_chat_model": chat_config,
            "default_embedding_model": embedding_config,
        },
        "local_search": {
            "chat_model_id": "default_chat_model",
            "embedding_model_id": "default_embedding_model",
        },
        "drift_search": DRIFTSearchConfig(
            chat_model_id="default_chat_model",
            embedding_model_id="default_embedding_model",
        ),
        "vector_store": {
            "default_vector_store": vector_store_config,
        },
    },
    root_dir=Path("."),
)


def _read_table(table_name: str) -> pd.DataFrame:
    """Load the parquet table ``<INPUT_DIR>/<table_name>.parquet`` into a DataFrame."""
    return pd.read_parquet(f"{INPUT_DIR}/{table_name}.parquet")


community_df = _read_table(COMMUNITY_TABLE)
entity_df = _read_table(ENTITY_TABLE)
relationship_df = _read_table(RELATIONSHIP_TABLE)
report_df = _read_table(COMMUNITY_REPORT_TABLE)
text_unit_df = _read_table(TEXT_UNIT_TABLE)

# Row-count summary so a missing or empty table is visible before querying.
print("DataFrames loaded:")
for label, frame in [
    ("Communities", community_df),
    ("Community reports", report_df),
    ("Entities", entity_df),
    ("Relationships", relationship_df),
    ("Text units", text_unit_df),
]:
    print(f"  {label}: {len(frame)}")


#### Run DRIFT search


In [None]:
response, context = await api.drift_search(
    config=config,
    entities=entity_df,
    communities=community_df,
    community_reports=report_df,
    text_units=text_unit_df,
    relationships=relationship_df,
    community_level=COMMUNITY_LEVEL,
    response_type="Multiple Paragraphs",
    query="Who is Agent Mercer?",
)
response


#### Inspect DRIFT context data


In [None]:
# Display the context object returned alongside the response by api.drift_search.
context
