# Causal Search Demo - Query and Context Analysis

This notebook demonstrates how to use the Causal Search method in GraphRAG and inspect the context used to generate responses. Causal Search performs causal analysis on knowledge graphs through a two-stage process:

1. **Stage 1**: Extract extended graph information (k + s nodes) and generate a causal analysis report
2. **Stage 2**: Use the causal report to generate the final response to the user query
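
To make the two stages concrete, here is a minimal sketch of the control flow in plain Python. It is not the GraphRAG implementation. `two_stage_causal_answer`, the prompt wording, and the `llm` callable (any async function that maps a prompt to text) are hypothetical stand-ins.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical model interface: any async function that takes a prompt and returns text.
LLM = Callable[[str], Awaitable[str]]


async def two_stage_causal_answer(query: str, graph_context: str, llm: LLM) -> tuple[str, str]:
    """Illustrative two-stage flow: build a causal report, then answer from it."""
    # Stage 1: ask the model for a causal analysis of the extended (k + s) graph context.
    report = await llm(
        "Analyze the causal relationships in the following graph data.\n\n"
        f"{graph_context}"
    )
    # Stage 2: answer the user's question using the stage-1 report as context.
    answer = await llm(
        f"Using this causal report:\n{report}\n\nAnswer the question: {query}"
    )
    return report, answer


async def _echo_llm(prompt: str) -> str:
    # Stand-in model for demonstration: describes the prompt instead of calling an API.
    return f"<response to {len(prompt)}-char prompt>"


if __name__ == "__main__":
    report, answer = asyncio.run(
        two_stage_causal_answer("What drives X?", "A -> B (weight 2)", _echo_llm)
    )
    print(report)
    print(answer)
```

The key design point is that stage 2 never sees the raw graph data, only the stage-1 report.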

## Key Features

- Extended node extraction beyond local search limits
- Two-stage processing for comprehensive causal analysis
- Automatic output saving to data folders
- Configurable parameters for retrieval breadth and context proportions
- Integration with existing GraphRAG pipeline
- **Context inspection**: See exactly what data was used to generate responses

## Prerequisites

Before running this notebook, ensure you have:

1. Run the GraphRAG indexing pipeline to generate entities, relationships, and community reports
2. Set up your configuration in `settings.yaml` with causal search parameters
3. Configured your language models and API keys
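
A quick way to confirm the first prerequisite is to check that the indexing run wrote the tables this notebook reads. The sketch below assumes the default `output/` layout and the table names used later in this notebook. `missing_index_outputs` is a hypothetical helper, and covariates are left out because they are optional.

```python
from pathlib import Path

# Parquet tables this notebook reads from the indexing output (covariates are optional).
REQUIRED_TABLES = ["entities", "communities", "relationships", "text_units", "community_reports"]


def missing_index_outputs(output_dir: Path) -> list[str]:
    """Return the required parquet tables that are not present in output_dir."""
    return [
        name for name in REQUIRED_TABLES
        if not (output_dir / f"{name}.parquet").is_file()
    ]


if __name__ == "__main__":
    missing = missing_index_outputs(Path("./ragtest/output"))
    if missing:
        print("Run the indexing pipeline first; missing:", ", ".join(missing))
    else:
        print("All required index outputs found")
```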

---

In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path

import pandas as pd
import tiktoken

# GraphRAG imports
from graphrag.config.enums import ModelType
from graphrag.config.load_config import load_config
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.factory import get_causal_search_engine
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.causal_search.search import CausalSearchError
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# IPython display utilities
from IPython.display import Markdown, display

## Configuration Setup

First, let's load the GraphRAG configuration and set up the environment.

In [None]:
# Configuration setup
ROOT_DIR = Path("./ragtest")  # Adjust this path to your project root
CONFIG_FILE = None  # Use default settings.yaml

# Load configuration
try:
    config = load_config(ROOT_DIR, CONFIG_FILE)
    print("✅ Configuration loaded successfully")
    print(f"📁 Root directory: {ROOT_DIR}")
    print(f"🔧 Causal search s_parameter: {config.causal_search.s_parameter}")
    print(f"🔧 Causal search top_k_entities: {config.causal_search.top_k_mapped_entities}")
    print(f"🔧 Causal search max_context_tokens: {config.causal_search.max_context_tokens}")
except Exception as e:
    print(f"❌ Failed to load configuration: {e}")
    raise

## Data Loading

Load the required data from your GraphRAG pipeline outputs using the same functions as the visualization notebook.

In [None]:
# Data loading setup
INPUT_DIR = f"{ROOT_DIR}/output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2

### Load tables to dataframes

#### Read entities

In [None]:
# Read the entity and community tables, then build the entity objects used by the context builder
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

print(f"✅ Loaded {len(entity_df)} entities")
print(f"✅ Loaded {len(community_df)} communities")

#### Read relationships

In [None]:
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"✅ Loaded {len(relationship_df)} relationships")

#### Read other data tables

In [None]:
# Load text units
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

# Load community reports
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)

# Load covariates if they exist
try:
    covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
    claims = read_indexer_covariates(covariate_df)
    covariates = {"claims": claims}
    print(f"✅ Loaded {len(claims)} covariates")
except FileNotFoundError:
    print("ℹ️  No covariates found, proceeding without covariates")
    covariates = {}

print(f"✅ Loaded {len(text_units)} text units")
print(f"✅ Loaded {len(reports)} community reports")

## Model Setup

Set up the language models and context builder using the same approach as the visualization notebook.

In [None]:
# Get model configurations from the loaded config
chat_model_config = config.get_language_model_config("default_chat_model")
embedding_model_config = config.get_language_model_config("default_embedding_model")

# Create chat model
chat_model = ModelManager().get_or_create_chat_model(
    name="causal_search",
    model_type=chat_model_config.type,
    config=chat_model_config,
)

# Create token encoder
token_encoder = tiktoken.encoding_for_model(chat_model_config.model)

# Create embedding model
text_embedder = ModelManager().get_or_create_embedding_model(
    name="causal_search_embedding",
    model_type=embedding_model_config.type,
    config=embedding_model_config,
)

# Create vector store
description_embedding_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

print("✅ Models and vector store setup complete")

## Context Builder Setup

Create the context builder using the same parameters as the visualization notebook.

In [None]:
# Context builder parameters (same as visualization notebook)
context_builder_params = {
    "text_unit_prop": 0.5,
    "community_prop": 0.1,
    "conversation_history_max_turns": 5,
    "conversation_history_user_turns_only": True,
    "top_k_mapped_entities": 10,  # Increased for causal search
    "top_k_relationships": 10,     # Increased for causal search
    "include_entity_rank": True,
    "include_relationship_weight": True,
    "include_community_rank": False,
    "return_candidate_context": False,
    "embedding_vectorstore_key": EntityVectorStoreKey.ID,
    "max_tokens": 80_000,
}

# Create context builder
context_builder = LocalSearchMixedContext(
    community_reports=reports,
    text_units=text_units,
    entities=entities,
    relationships=relationships,
    covariates=covariates,
    entity_text_embeddings=description_embedding_store,
    embedding_vectorstore_key=EntityVectorStoreKey.ID,
    text_embedder=text_embedder,
    token_encoder=token_encoder,
)

print("✅ Context builder setup complete")
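
The `text_unit_prop` and `community_prop` settings divide `max_tokens` between data types. As we understand `LocalSearchMixedContext`, community reports get `community_prop`, text units get `text_unit_prop`, and the remainder goes to entities, relationships, and covariates. The helper below is a hypothetical sketch of that arithmetic, not the library code:

```python
def split_context_budget(max_tokens: int, text_unit_prop: float, community_prop: float) -> dict[str, int]:
    """Split a context token budget into community, local (graph), and text-unit shares."""
    if community_prop + text_unit_prop > 1:
        raise ValueError("community_prop + text_unit_prop must not exceed 1")
    # Whatever is not reserved for community reports or text units goes to local graph data.
    local_prop = 1 - community_prop - text_unit_prop
    return {
        "community_reports": max(int(max_tokens * community_prop), 0),
        "local_graph_data": max(int(max_tokens * local_prop), 0),
        "text_units": max(int(max_tokens * text_unit_prop), 0),
    }


if __name__ == "__main__":
    # Values taken from context_builder_params above.
    print(split_context_budget(80_000, text_unit_prop=0.5, community_prop=0.1))
```

With the parameters above (`max_tokens=80_000`, `text_unit_prop=0.5`, `community_prop=0.1`) this yields 8,000 tokens for community reports, 32,000 for entities, relationships, and covariates, and 40,000 for text units. How the causal engine's own `max_context_tokens=12_000` interacts with this budget is not shown in this notebook.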

## Causal Search Engine Setup

Create the causal search engine with the same model parameters as the visualization notebook.

In [None]:
# Model parameters (same as visualization notebook)
model_params = {
    "max_tokens": 4_000,  # Adjusted for gpt-4-turbo-preview
    "temperature": 0.0,
}

# Create causal search engine
causal_search_engine = get_causal_search_engine(
    model=chat_model,
    context_builder=context_builder,
    token_encoder=token_encoder,
    model_params=model_params,
    context_builder_params=context_builder_params,
    s_parameter=3,  # Additional nodes for causal analysis
    max_context_tokens=12_000,
)

print("✅ Causal search engine setup complete")
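
Stage 1 works on k + s nodes rather than only k. The sketch below illustrates that selection, assuming k corresponds to `top_k_mapped_entities` (10 here), so `s_parameter=3` would give up to 13 candidate entities. `select_extended_nodes` and the similarity scores are hypothetical; in the real pipeline the scores would come from the embedding-based entity mapping.

```python
def select_extended_nodes(scores: dict[str, float], k: int, s: int) -> tuple[list[str], list[str]]:
    """Return (local_nodes, causal_nodes): the top-k and top-(k + s) entities by score."""
    # Highest score first; sorted() is stable, so ties keep their input order.
    ranked = sorted(scores, key=lambda name: scores[name], reverse=True)
    return ranked[:k], ranked[: k + s]


if __name__ == "__main__":
    similarity = {"A": 0.91, "B": 0.85, "C": 0.80, "D": 0.42, "E": 0.30}
    local, causal = select_extended_nodes(similarity, k=2, s=2)
    print(local)   # ['A', 'B']
    print(causal)  # ['A', 'B', 'C', 'D']
```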

## Causal Search Example

Now let's run a causal search query and inspect the context used to generate the response.

### Run causal search on a sample query

In [None]:
# Sample query for causal analysis
question = "What are the causal relationships in this dataset?"
print(f"🔍 Query: {question}")

# Execute causal search
try:
    result = await causal_search_engine.search(question)
    print("✅ Causal search completed successfully!")
except Exception as e:
    print(f"❌ Causal search failed: {e}")
    raise
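
The top-level `await` above works because Jupyter runs its own event loop. In a standalone script there is no running loop, so the coroutine has to be driven with `asyncio.run`. The sketch uses a hypothetical `StubCausalEngine` with the same `search(query)` shape so it runs without GraphRAG; in practice you would pass `causal_search_engine` instead.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class StubResult:
    # Minimal stand-in for a search result; the real result object carries more fields.
    response: str


class StubCausalEngine:
    # Hypothetical engine exposing the same async `search(query)` shape used above.
    async def search(self, query: str) -> StubResult:
        await asyncio.sleep(0)  # yield to the event loop, as a real network call would
        return StubResult(response=f"Answer to: {query}")


async def main(engine, question: str) -> str:
    result = await engine.search(question)
    return result.response


if __name__ == "__main__":
    # In a plain .py script there is no running event loop, so drive the coroutine explicitly.
    print(asyncio.run(main(StubCausalEngine(), "What are the causal relationships?")))
```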

### Display the response

In [None]:
# Display as formatted Markdown
print("\n📝 Causal Search Response:")
print("=" * 50)
display(Markdown(result.response))

## Inspecting the Context Data

Now let's examine exactly what data was used to generate the response. This is the key part that shows the context filtering in action.

### Context data overview

In [None]:
print("\n🔍 Context Data Overview:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data:
    print(f"📊 Entities: {len(result.context_data.get('entities', []))}")
    print(f"🔗 Relationships: {len(result.context_data.get('relationships', []))}")
    print(f"📄 Text Units: {len(result.context_data.get('text_units', []))}")
    print(f"🏘️  Community Reports: {len(result.context_data.get('community_reports', []))}")
else:
    print("ℹ️  No context_data available in result")
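
The inspection cells below repeat the same guard (`hasattr(result, 'context_data') and result.context_data and key in ...`). A small helper can centralize it. `get_context_table` is hypothetical; it assumes `context_data` is a dict of tables, as the cells in this notebook do, and uses `len()` rather than truthiness because calling `bool()` on a pandas DataFrame raises a `ValueError`.

```python
from types import SimpleNamespace
from typing import Any, Optional


def get_context_table(result: Any, key: str) -> Optional[Any]:
    """Return result.context_data[key] if present and non-empty, else None."""
    context_data = getattr(result, "context_data", None)
    if not isinstance(context_data, dict):
        return None
    table = context_data.get(key)
    # len() works for DataFrames and lists alike, unlike a bare truthiness check.
    if table is None or len(table) == 0:
        return None
    return table


if __name__ == "__main__":
    demo = SimpleNamespace(context_data={"entities": ["A", "B"], "relationships": []})
    print(get_context_table(demo, "entities"))       # ['A', 'B']
    print(get_context_table(demo, "relationships"))  # None
    print(get_context_table(demo, "text_units"))     # None
```

Each inspection cell could then start with `entities_df = get_context_table(result, 'entities')` and branch on `entities_df is None`.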

### Inspect entities used in context

In [None]:
print("\n🏷️  Entities Used in Context:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data and 'entities' in result.context_data:
    entities_df = result.context_data['entities']
    if not entities_df.empty:
        # Show key fields
        display_cols = ['entity', 'description', 'rank', 'type']
        available_cols = [col for col in display_cols if col in entities_df.columns]
        
        if available_cols:
            display(entities_df[available_cols].head(10))
        else:
            display(entities_df.head(10))
        
        print(f"\n📈 Total entities in context: {len(entities_df)}")
    else:
        print("ℹ️  No entities found in context")
else:
    print("ℹ️  No entities data available")

### Inspect relationships used in context

In [None]:
print("\n🔗 Relationships Used in Context:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data and 'relationships' in result.context_data:
    relationships_df = result.context_data['relationships']
    if not relationships_df.empty:
        # Show key fields
        display_cols = ['source', 'target', 'description', 'weight', 'rank']
        available_cols = [col for col in display_cols if col in relationships_df.columns]
        
        if available_cols:
            display(relationships_df[available_cols].head(10))
        else:
            display(relationships_df.head(10))
        
        print(f"\n📈 Total relationships in context: {len(relationships_df)}")
    else:
        print("ℹ️  No relationships found in context")
else:
    print("ℹ️  No relationships data available")

### Inspect text units used in context

In [None]:
print("\n📄 Text Units Used in Context:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data and 'text_units' in result.context_data:
    text_units_df = result.context_data['text_units']
    if not text_units_df.empty:
        # Show key fields
        display_cols = ['text', 'n_tokens']
        available_cols = [col for col in display_cols if col in text_units_df.columns]
        
        if available_cols:
            # Truncate text for display
            display_df = text_units_df[available_cols].copy()
            if 'text' in display_df.columns:
                display_df['text'] = display_df['text'].str[:200] + '...'
            display(display_df.head(10))
        else:
            display(text_units_df.head(10))
        
        print(f"\n📈 Total text units in context: {len(text_units_df)}")
    else:
        print("ℹ️  No text units found in context")
else:
    print("ℹ️  No text units data available")
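
The truncation above appends `'...'` to every value, including texts shorter than 200 characters. A hypothetical `truncate` helper adds the suffix only when something was actually cut, and passes non-string values (such as missing text) through unchanged:

```python
from typing import Any


def truncate(text: Any, limit: int = 200, suffix: str = "...") -> Any:
    """Shorten strings longer than `limit`; leave short strings and non-strings unchanged."""
    if isinstance(text, str) and len(text) > limit:
        return text[:limit] + suffix
    return text


if __name__ == "__main__":
    print(truncate("short text"))           # short text
    print(truncate("x" * 250, limit=10))    # xxxxxxxxxx...
```

Apply it column-wise with `display_df['text'] = display_df['text'].map(truncate)`; the same call works for the `summary` column of community reports.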

### Inspect community reports used in context

In [None]:
print("\n🏘️  Community Reports Used in Context:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data and 'community_reports' in result.context_data:
    community_reports_df = result.context_data['community_reports']
    if not community_reports_df.empty:
        # Show key fields
        display_cols = ['community_id', 'summary', 'description']
        available_cols = [col for col in display_cols if col in community_reports_df.columns]
        
        if available_cols:
            # Truncate summary for display
            display_df = community_reports_df[available_cols].copy()
            if 'summary' in display_df.columns:
                display_df['summary'] = display_df['summary'].str[:200] + '...'
            display(display_df[available_cols].head(10))
        else:
            display(community_reports_df.head(10))
        
        print(f"\n📈 Total community reports in context: {len(community_reports_df)}")
    else:
        print("ℹ️  No community reports found in context")
else:
    print("ℹ️  No community reports data available")

## Context Filtering Analysis

Let's analyze how the context filtering worked and compare it to the original data.

### Compare original vs. filtered data

In [None]:
print("\n🔍 Context Filtering Analysis:")
print("=" * 50)

print(f"📊 Original data:")
print(f"   - Entities: {len(entity_df)}")
print(f"   - Relationships: {len(relationship_df)}")
print(f"   - Text Units: {len(text_unit_df)}")
print(f"   - Community Reports: {len(report_df)}")

if hasattr(result, 'context_data') and result.context_data:
    print(f"\n🎯 Filtered context data:")
    print(f"   - Entities: {len(result.context_data.get('entities', []))}")
    print(f"   - Relationships: {len(result.context_data.get('relationships', []))}")
    print(f"   - Text Units: {len(result.context_data.get('text_units', []))}")
    print(f"   - Community Reports: {len(result.context_data.get('community_reports', []))}")
    
    # Calculate filtering ratios
    entity_ratio = len(result.context_data.get('entities', [])) / len(entity_df) * 100
    relationship_ratio = len(result.context_data.get('relationships', [])) / len(relationship_df) * 100
    text_unit_ratio = len(result.context_data.get('text_units', [])) / len(text_unit_df) * 100
    
    print(f"\n📈 Filtering ratios:")
    print(f"   - Entities: {entity_ratio:.1f}% kept")
    print(f"   - Relationships: {relationship_ratio:.1f}% kept")
    print(f"   - Text Units: {text_unit_ratio:.1f}% kept")
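
The filtering ratios divide by the size of each source table, which raises `ZeroDivisionError` if a table is empty. A guarded version, sketched with the hypothetical helpers `kept_percent` and `format_ratio`, reports `n/a` instead and can also cover community reports:

```python
from typing import Optional


def kept_percent(kept: int, total: int) -> Optional[float]:
    """Percentage of `total` items that survived filtering; None when total is zero."""
    if total <= 0:
        return None
    return 100 * kept / total


def format_ratio(label: str, kept: int, total: int) -> str:
    # Mirrors the notebook's print format, with a fallback for empty source tables.
    pct = kept_percent(kept, total)
    if pct is None:
        return f"   - {label}: n/a (no source rows)"
    return f"   - {label}: {pct:.1f}% kept"


if __name__ == "__main__":
    print(format_ratio("Entities", 7, 40))   #    - Entities: 17.5% kept
    print(format_ratio("Text Units", 0, 0))  #    - Text Units: n/a (no source rows)
```

For example, `print(format_ratio("Community Reports", len(result.context_data.get('community_reports', [])), len(report_df)))` adds the ratio that the cell above leaves out.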

### Token usage analysis

In [None]:
print("\n🔢 Token Usage Analysis:")
print("=" * 50)

if hasattr(result, 'context_data') and result.context_data:
    # Estimate token usage for each data type
    total_tokens = 0
    
    # Entities tokens
    if 'entities' in result.context_data and not result.context_data['entities'].empty:
        entities_text = result.context_data['entities'].to_string()
        entities_tokens = len(token_encoder.encode(entities_text))
        total_tokens += entities_tokens
        print(f"   - Entities: ~{entities_tokens:,} tokens")
    
    # Relationships tokens
    if 'relationships' in result.context_data and not result.context_data['relationships'].empty:
        relationships_text = result.context_data['relationships'].to_string()
        relationships_tokens = len(token_encoder.encode(relationships_text))
        total_tokens += relationships_tokens
        print(f"   - Relationships: ~{relationships_tokens:,} tokens")
    
    # Text units tokens
    if 'text_units' in result.context_data and not result.context_data['text_units'].empty:
        text_units_text = result.context_data['text_units'].to_string()
        text_units_tokens = len(token_encoder.encode(text_units_text))
        total_tokens += text_units_tokens
        print(f"   - Text Units: ~{text_units_tokens:,} tokens")
    
    print(f"\n📊 Total estimated tokens: ~{total_tokens:,}")
    print(f"🎯 Target limit: 8,000 tokens (network data portion)")
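
The estimates above render each table with `DataFrame.to_string()` and count tokens with the notebook's `token_encoder`; the prompt the engine actually sends may be formatted differently, so treat the numbers as approximate. The same idea can be written as a reusable function over any set of rendered tables. `estimate_context_tokens` is hypothetical, and `encode` is any callable that returns a token sequence (for example `token_encoder.encode`):

```python
from typing import Callable, Mapping, Sequence


def estimate_context_tokens(
    rendered_tables: Mapping[str, str],
    encode: Callable[[str], Sequence[int]],
) -> dict[str, int]:
    """Count tokens per rendered table and add a 'total' entry."""
    counts = {name: len(encode(text)) for name, text in rendered_tables.items()}
    counts["total"] = sum(counts.values())
    return counts


if __name__ == "__main__":
    # Whitespace splitting is only a crude stand-in for a real tokenizer.
    print(estimate_context_tokens({"entities": "a b c", "community_reports": "d e"}, str.split))
```

Assuming every value in `context_data` is a DataFrame, `estimate_context_tokens({name: df.to_string() for name, df in result.context_data.items() if len(df)}, token_encoder.encode)` counts every table in the context, community reports included.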

## Advanced Context Inspection

Let's look deeper at the context-building process by reviewing the parameters that drive its filtering decisions.

### Check context builder parameters

In [None]:
print("\n⚙️  Context Builder Parameters Used:")
print("=" * 50)

for key, value in context_builder_params.items():
    print(f"   - {key}: {value}")

### Check model parameters

In [None]:
print("\n🤖 Model Parameters Used:")
print("=" * 50)

for key, value in model_params.items():
    print(f"   - {key}: {value}")

### Check causal search parameters

In [None]:
print("\n🔍 Causal Search Parameters:")
print("=" * 50)

print(f"   - s_parameter: {causal_search_engine.s_parameter}")
print(f"   - max_context_tokens: {causal_search_engine.max_context_tokens}")
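
To compare runs later, it helps to save these three parameter groups together. The sketch below writes them to one JSON file. `save_run_parameters` and the output filename are hypothetical, and enum members (such as `EntityVectorStoreKey.ID`) are written as their values:

```python
import json
from enum import Enum
from pathlib import Path
from typing import Any


def _to_jsonable(value: Any) -> Any:
    # Fallback for json.dumps: enums become their value, anything else unknown becomes str().
    if isinstance(value, Enum):
        return value.value
    return str(value)


def save_run_parameters(path: Path, **sections: dict[str, Any]) -> Path:
    """Write each parameter group to one JSON file so a causal-search run can be reproduced."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(sections, indent=2, default=_to_jsonable), encoding="utf-8")
    return path
```

Usage in this notebook might look like `save_run_parameters(ROOT_DIR / "output" / "causal_search_params.json", context_builder=context_builder_params, model=model_params, causal={"s_parameter": 3, "max_context_tokens": 12_000})`.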

## Summary

This notebook demonstrates:

1. **Data Loading**: Using the same functions as the visualization notebook
2. **Model Setup**: Consistent with the visualization notebook approach
3. **Context Building**: Same parameters and structure
4. **Causal Search**: Extended node extraction and two-stage processing
5. **Context Inspection**: Detailed analysis of what data was used
6. **Filtering Analysis**: Understanding how context filtering works

The key insight is that causal search filters its context to fit a token budget, which ensures:
- **LLM compatibility**: The prompt fits within the model's context window
- **Relevance**: The most relevant (highest-ranked) entities and relationships are preserved
- **Reliability**: Requests do not fail with context-length errors

The drop in volume between the extracted nodes and the final context (e.g., 40+ nodes reduced to 7 entities) is therefore expected rather than a loss: items that do not fit in the token budget are left out, while the most relevant information is retained.
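
One simple mechanism that produces this kind of reduction is greedy filling of a token budget: candidates are taken in rank order, and the loop stops at the first item that would overflow the budget. The sketch is a simplified illustration with made-up token costs, not the causal search code, and `fill_to_budget` is hypothetical:

```python
def fill_to_budget(items: list[tuple[str, int]], budget: int) -> list[str]:
    """Keep (name, tokens) items in rank order until the next one would exceed the budget."""
    kept, used = [], 0
    for name, tokens in items:
        if used + tokens > budget:
            break  # stop at the first overflow; lower-ranked items are not considered
        kept.append(name)
        used += tokens
    return kept


if __name__ == "__main__":
    ranked = [("A", 300), ("B", 250), ("C", 400), ("D", 200)]
    print(fill_to_budget(ranked, budget=800))  # ['A', 'B']
```

Note that `D` is dropped even though it would still fit after `A` and `B` (750 of 800 tokens): once `C` overflows the budget, lower-ranked items are never reached.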