# Agentic Graph RAG - Interactive Prototype

This notebook walks through the key components of the Graph RAG system, in the order of the sections below:
1. Domain ontology (schema)
2. Graph store and sample data ingestion
3. World state (epistemic state tracking)
4. Simulation tools
5. Full query workflow
6. Graph exploration

In [None]:
# Setup path: put the parent directory at the front of the import search path,
# presumably so the `src.*` packages imported by later cells can be found.
# NOTE(review): '..' is resolved against the kernel's working directory, so this
# assumes the notebook is run from its own folder — confirm.
import sys
sys.path.insert(0, '..')

# Load settings from a .env file into the environment (python-dotenv); a later
# cell reads GEMINI_API_KEY from the environment.
from dotenv import load_dotenv
load_dotenv()

## 1. Domain Ontology

The ontology defines the schema for our knowledge graph.

In [None]:
from src.schema.ontology import create_quant_finance_ontology

# Build the ontology object; later cells hand it to the store and the workflow.
ontology = create_quant_finance_ontology()

# Gather what we want to show first, then print a short summary.
node_type_names = list(ontology.node_schemas.keys())
edge_type_names = list(ontology.edge_schemas.keys())
# Only the first five domain terms are shown, to keep the output short.
domain_term_preview = list(ontology.domain_terms.items())[:5]

print(f"Ontology: {ontology.name}")
print(f"\nNode Types: {node_type_names}")
print(f"\nEdge Types: {edge_type_names}")
print(f"\nDomain Terms: {domain_term_preview}")

In [None]:
# Render the ontology as its Cypher schema text, then display it
cypher_schema = ontology.to_cypher_schema()
print(cypher_schema)

## 2. Graph Store

Initialize the NetworkX store and add sample data.

In [None]:
from src.graph.networkx_store import NetworkXStore
from src.schema.ontology import Entity, Relationship

# A new store instance is created every time this cell runs, then bound to the ontology.
store = NetworkXStore()
store.initialize(ontology)

# Sample nodes, one row per entity: (id, node_type, properties)
entity_rows = [
    ("spx", "Instrument", {"symbol": "SPX", "name": "S&P 500 Index", "asset_class": "equity_index"}),
    ("vix", "Instrument", {"symbol": "VIX", "name": "CBOE Volatility Index", "asset_class": "volatility_index"}),
    ("straddle", "Strategy", {"name": "Long Straddle", "strategy_type": "volatility", "risk_profile": "long_vol"}),
    ("iron_condor", "Strategy", {"name": "Iron Condor", "strategy_type": "income", "risk_profile": "short_vol"}),
    ("high_vol", "MarketCondition", {"name": "High Volatility Regime", "vix_range": "25-40"}),
    ("low_vol", "MarketCondition", {"name": "Low Volatility Regime", "vix_range": "12-18"}),
    ("vega", "RiskFactor", {"name": "Vega", "greek": "vega", "description": "Sensitivity to volatility"}),
]
entities = [
    Entity(id=entity_id, node_type=node_type, properties=props)
    for entity_id, node_type, props in entity_rows
]

for sample_entity in entities:
    store.add_entity(sample_entity)

# Sample edges, one row per relationship: (id, edge_type, source_id, target_id, properties)
relationship_rows = [
    ("r1", "TRADES", "straddle", "spx", {"direction": "long"}),
    ("r2", "PERFORMS_IN", "straddle", "high_vol", {"expected_pnl": "positive"}),
    ("r3", "PERFORMS_IN", "iron_condor", "low_vol", {"expected_pnl": "positive"}),
    ("r4", "HEDGES", "straddle", "vega", {"hedge_ratio": 1.0}),
    ("r5", "AFFECTED_BY", "spx", "vega", {"sensitivity": "high"}),
]
relationships = [
    Relationship(id=rel_id, edge_type=edge_type, source_id=src, target_id=dst, properties=props)
    for rel_id, edge_type, src, dst, props in relationship_rows
]

for sample_rel in relationships:
    store.add_relationship(sample_rel)

# Report how much was loaded, followed by the store's own statistics
entity_count = len(entities)
relationship_count = len(relationships)
print(f"Added {entity_count} entities and {relationship_count} relationships")
print(f"\nGraph statistics: {store.get_statistics()}")

## 3. World State (Epistemic Tracking)

The world state tracks beliefs, uncertainties, and answer completeness.

In [None]:
from src.schema.world_state import WorldState, Belief, BeliefStatus, Uncertainty, UncertaintyType

# World state for a single example query, limited to 10 iterations
state = WorldState(
    original_query="What happens to a straddle when VIX spikes?",
    max_iterations=10,
)

# One confirmed belief, held with 0.9 confidence
confirmed_belief = Belief(
    id="b1",
    content="Straddles profit from volatility increases",
    status=BeliefStatus.CONFIRMED,
    confidence=0.9,
)
state.add_belief(confirmed_belief)

# One open counterfactual question, with priority 0.8
open_question = Uncertainty(
    id="u1",
    description="How much does the straddle profit for a 20% VIX spike?",
    uncertainty_type=UncertaintyType.COUNTERFACTUAL,
    priority=0.8,
)
state.add_uncertainty(open_question)

# Set the answer completeness by hand for this demo
state.answer_completeness = 0.6

# Show the state's context text, then the routing decision it yields
context_text = state.to_context_string()
print(context_text)
routing_decision = state.get_routing_decision()
print(f"\nRouting decision: {routing_decision}")

## 4. Simulation Tools

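These tools support quantitative scenario analysis: computing Black-Scholes option Greeks and estimating the P&L impact of a VIX spike on different strategies.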

In [None]:
from src.tools.simulation_tools import SimulationTools

# Inputs for an at-the-money call option (spot equals strike)
option_inputs = dict(
    spot=4500,
    strike=4500,
    time_to_expiry=30/365,  # 30 days, expressed as a fraction of a year
    volatility=0.20,
    risk_free_rate=0.05,
    is_call=True,
)

# Calculate option Greeks
greeks = SimulationTools.calculate_black_scholes_greeks(**option_inputs)

# Assemble the report lines and emit them in one go
greek_lines = [
    "ATM Call Greeks:",
    f"  Delta: {greeks.delta}",
    f"  Gamma: {greeks.gamma}",
    f"  Theta: {greeks.theta}",
    f"  Vega: {greeks.vega}",
]
print("\n".join(greek_lines))

In [None]:
# Estimate the impact of the same VIX change (vix_change_percent=20) on each
# strategy, using the same position size for all of them
for strategy_name in ("long_straddle", "short_straddle", "iron_condor"):
    impact = SimulationTools.estimate_vix_spike_impact(
        vix_change_percent=20,
        strategy_type=strategy_name,
        position_size=100000,
    )
    print(f"\n{strategy_name}:")
    print(f"  P&L Estimate: ${impact.pnl_estimate:,.0f}")
    # Name the two endpoints of the range before formatting them
    pnl_low = impact.pnl_range[0]
    pnl_high = impact.pnl_range[1]
    print(f"  P&L Range: ${pnl_low:,.0f} to ${pnl_high:,.0f}")
    print(f"  Recommendation: {impact.recommendation}")

## 5. Full Query Workflow

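Run a complete query through the Graph RAG system. This step requires the `GEMINI_API_KEY` environment variable to be set; without it, the cell only prints a reminder.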

In [None]:
# This requires GEMINI_API_KEY to be set
import os

if os.getenv("GEMINI_API_KEY"):
    # Imported only on this branch, i.e. only when the key is available
    from src.llm.gemini_provider import GeminiProvider
    from src.graph.langgraph_workflow import GraphRAGWorkflow

    # Wire the LLM provider to the ontology and graph store built in earlier cells
    llm = GeminiProvider()
    workflow = GraphRAGWorkflow(llm, ontology, store, max_iterations=5)

    # Run a query
    result = workflow.run_verbose("What happens to a straddle when VIX spikes 20%?")

    # Display results
    answer = result.get("final_answer")
    if answer:
        print(f"Answer: {answer.answer}")
        print(f"\nConfidence: {answer.confidence:.0%}")

        if answer.citations:
            print("\nCitations:")
            # Show at most the first three citations
            for c in answer.citations[:3]:
                print(f"  - {c.content}")
    else:
        # Say so explicitly: otherwise a run that produced no answer prints
        # nothing at all and looks the same as a cell that never ran.
        print("Workflow finished without a final answer")
else:
    print("Set GEMINI_API_KEY to run the full workflow")

## 6. Graph Exploration

Explore the knowledge graph interactively.

In [None]:
# List every Strategy entity with its name and risk profile
strategies = store.find_entities_by_type("Strategy")
print("Strategies:")
for strategy_entity in strategies:
    props = strategy_entity.properties
    print(f"  - {props.get('name')}: {props.get('risk_profile')}")

# Show each (neighbor, relationship) pair returned for the straddle node
print("\nStraddle connections:")
neighbors = store.get_neighbors("straddle")
for neighbor_entity, edge in neighbors:
    # Use the entity id when there is no 'name' property
    display_name = neighbor_entity.properties.get('name', neighbor_entity.id)
    print(f"  -{edge.edge_type}-> {display_name}")

In [None]:
# Query the store for "volatility" and list each returned entity with its type
results = store.search_entities("volatility")
print("Search results for 'volatility':")
for hit in results:
    # Use the entity id when there is no 'name' property
    hit_label = hit.properties.get('name', hit.id)
    print(f"  [{hit.node_type}] {hit_label}")