In [1]:
from dotenv import load_dotenv
from neo4j import GraphDatabase
import os
## Load environment variables from .env file
load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")
uri = os.getenv("NEO4J_URI")
user = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
AUTH = (user, password)

In [3]:
from neo4j import GraphDatabase


# --- prerequisites ------------------------------------------------------
driver   = GraphDatabase.driver(uri, auth=(user, password))


## Retriever = Approach - 1

In [35]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
embedder = OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-large")
llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=api_key)

In [5]:
from openai import OpenAI
import json

client = OpenAI(api_key=api_key)  # ideally os.getenv()

# -------------------- Entity Extraction -------------------- #
def extract_entities(query: str) -> list[str]:
    tools = [
        {
            "type": "function",
            "function": {
                "name": "extract_entities",
                "description": "Extracts all proper nouns or named entities from a query.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "entities": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of named entities"
                        }
                    },
                    "required": ["entities"]
                }
            }
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Extract named entities from the user question."},
            {"role": "user", "content": query}
        ],
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "extract_entities"}}
    )

    args = response.choices[0].message.tool_calls[0].function.arguments
    return json.loads(args)["entities"]


In [6]:
# -------------------- Query Intent Classification -------------------- #
def classify_intent(query: str) -> str:
    tools = [
        {
            "type": "function",
            "function": {
                "name": "classify_intent",
                "description": "Classifies the query intent as fact or context.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "intent": {
                            "type": "string",
                            "enum": ["fact", "context"],
                            "description": "Is it a fact query (asks for a list or factual triple) or a context query (asks for explanation)?"
                        }
                    },
                    "required": ["intent"]
                }
            }
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Classify the user query as 'fact' or 'context'."},
            {"role": "user", "content": query}
        ],
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "classify_intent"}}
    )

    args = response.choices[0].message.tool_calls[0].function.arguments
    return json.loads(args)["intent"]


In [7]:
query = "List all iPhones released by Apple in 2022"

entities = extract_entities(query)
intent = classify_intent(query)

print("Entities:", entities)
print("Intent:", intent)

Entities: ['iPhones', 'Apple', '2022']
Intent: fact


In [8]:
from typing import List, Dict, Tuple, Callable, Optional, Sequence

class TripleRetriever:
    def __init__(self, driver: GraphDatabase.driver):
        self.driver = driver

    def get_triples(self, query_entities: List[str], limit: int = 20) -> List[Dict]:
        if not query_entities:
            return []

        cypher = """
        MATCH (s)-[r]->(o)
        WHERE toLower(s.name) IN $entities OR toLower(o.name) IN $entities
        RETURN s.name AS subject, type(r) AS relation, o.name AS object
        LIMIT $limit
        """
        with self.driver.session() as session:
            rows = session.run(
                cypher,
                entities=[e.lower() for e in query_entities],
                limit=limit,
            ).data()
        return rows  # List of {"subject": ..., "relation": ..., "object": ...}


In [9]:
# hybrid_retriever.py
from __future__ import annotations
from typing import List, Dict, Tuple, Callable, Optional, Sequence
from collections import defaultdict

import numpy as np
from neo4j import GraphDatabase
from langchain.embeddings.base import Embeddings


class HybridRetriever:
    """
    A *single-entry-point* retriever that combines

    1. **Dense semantic search** over `(:Chunk)` vectors stored in Neo4j
    2. **Structural expansion** (PARENT_OF, NEXT_CHUNK) to keep narrative context
    3. **Knowledge-graph probing** via light-weight NER →  `(:Entity)-[:MENTIONED_IN]->(:Chunk)`
    4. **Reciprocal-Rank Fusion** to merge the three candidate sets

    The end product is an *ordered* list of chunk dictionaries ready to be
    concatenated into your RAG prompt template.

    Parameters
    ----------
    driver :
        An active `neo4j.GraphDatabase.driver`.
    embedder :
        An object that implements `embed_query(text) -> List[float]`
        (e.g. `langchain_openai.OpenAIEmbeddings`).
    index_name :
        Name of the Neo4j vector index created by `create_vector_index`.
    entity_tagger :
        Callable that returns a list of surface strings found in a query.
        Default is a mini-regex fallback that picks capitalised noun-phrases;
        in production you will probably inject spaCy, HuggingFace, or an
        OpenAI function-calling wrapper.
    top_k_dense :
        How many dense seeds to return.
    expand_parents / expand_neighbors :
        How many *per-seed* hierarchical and adjacent chunks to pull in
        (0 disables the respective expansion).
    kg_chunk_limit :
        Per-entity cap when pulling chunks through the KG.
    max_chunks :
        Final cap after fusion/rerank – keeps your context window bounded.
    """

    def __init__(
        self,
        driver: GraphDatabase.driver,
        embedder: Embeddings,
        *,
        index_name: str,
        entity_tagger: Optional[Callable[[str], Sequence[str]]] = None,
        top_k_dense: int = 5,
        expand_parents: int = 1,
        expand_neighbors: int = 1,
        kg_chunk_limit: int = 3,
        max_chunks: int = 12,
    ) -> None:
        self.driver = driver
        self.embedder = embedder
        self.index_name = index_name
        self.top_k_dense = top_k_dense
        self.expand_parents = expand_parents
        self.expand_neighbors = expand_neighbors
        self.kg_chunk_limit = kg_chunk_limit
        self.max_chunks = max_chunks
        self.entity_tagger = entity_tagger or self._naive_capitalised_np

    # --------------------------------------------------------------------- #
    # 1) Dense search (vector index)                                         #
    # --------------------------------------------------------------------- #
    def _dense_seeds(self, query: str) -> Tuple[List[str], List[float]]:
        embedding = self.embedder.embed_query(query)
        cypher = (
            "CALL db.index.vector.queryNodes($index_name, $k, $embed) "
            "YIELD node, score "
            "RETURN node.id AS id, score "
            "ORDER BY score DESC"
        )
        with self.driver.session() as s:
            res = s.run(
                cypher,
                index_name=self.index_name,
                k=self.top_k_dense,
                embed=embedding,
            )
            rows = res.data()
        ids = [r["id"] for r in rows]
        scores = [r["score"] for r in rows]
        return ids, scores

    # --------------------------------------------------------------------- #
    # 2) Context expansion via PARENT_OF / NEXT_CHUNK                       #
    # --------------------------------------------------------------------- #
    def _expand_parents(self, chunk_ids: List[str], hops: int) -> set[str]:
        if hops == 0 or not chunk_ids:
            return set()
        cypher = f"""
            MATCH (c:Chunk)<-[:PARENT_OF*1..{hops}]-(p)
            WHERE c.id IN $ids
            RETURN DISTINCT p.id AS id
        """
        with self.driver.session() as s:
            rows = s.run(cypher, ids=chunk_ids).data()
        return {r["id"] for r in rows}


    def _expand_neighbors(self, chunk_ids: List[str], hops: int) -> set[str]:
        if hops == 0 or not chunk_ids:
            return set()
        cypher = f"""
            MATCH (c1:Chunk)-[:NEXT_CHUNK*1..{hops}]->(c2:Chunk)
            WHERE c1.id IN $ids
            RETURN DISTINCT c2.id AS id
        """
        with self.driver.session() as s:
            rows = s.run(cypher, ids=chunk_ids).data()
        return {r["id"] for r in rows}


    # --------------------------------------------------------------------- #
    # 3) Knowledge-graph probe                                              #
    # --------------------------------------------------------------------- #
    def kg_probe_from_entities(self, entities: List[str]) -> List[str]:
        if not entities:
            return []
        cypher = (
            "MATCH (e)-[:MENTIONED_IN]->(c:Chunk) "
            "WHERE toLower(e.name) IN $ents "
            "RETURN DISTINCT c.id AS id "
            "LIMIT $lim"
        )
        with self.driver.session() as s:
            rows = s.run(
                cypher,
                ents=[e.lower() for e in entities],
                lim=self.kg_chunk_limit * len(entities),
            ).data()
        return [r["id"] for r in rows]


    # --------------------------------------------------------------------- #
    # 4) Reciprocal-Rank-Fusion                                             #
    # --------------------------------------------------------------------- #
    def _rrf_fuse(self, runs: List[Tuple[str, List[str]]], k: int = 60) -> List[Tuple[str, float]]:
        """
        Each `runs[i]` is (label, [id0, id1, ...]).

        RRF score = 1 / (k + rank).  Higher is better.
        """
        scores = defaultdict(float)
        for _, ids in runs:
            for rank, cid in enumerate(ids):
                scores[cid] += 1.0 / (k + rank + 1)
        # Sort: high score first, deterministic tie-break by id
        return sorted(scores.items(), key=lambda x: (-x[1], x[0]))

    # --------------------------------------------------------------------- #
    # 5) Fetch the chunk text                                               #
    # --------------------------------------------------------------------- #
    def _hydrate_chunks(self, ids: List[str]) -> List[Dict]:
        if not ids:
            return []
        cypher = (
            "MATCH (c:Chunk) WHERE c.id IN $ids "
            "RETURN c.id AS id, c.text AS text"
        )
        with self.driver.session() as s:
            rows = s.run(cypher, ids=ids).data()
        lookup = {r["id"]: r["text"] for r in rows}
        return [{"id": cid, "text": lookup[cid]} for cid in ids if cid in lookup]

    # --------------------------------------------------------------------- #
    # Fallback entity tagger                                                #
    # --------------------------------------------------------------------- #
    @staticmethod
    def _naive_capitalised_np(text: str) -> List[str]:
        """Very small regex-based heuristic: sequences of ≥2 capitalised words."""
        import re

        patt = re.compile(r"\b([A-Z][\w\-]*(?:\s+[A-Z][\w\-]*)+)\b")
        return patt.findall(text)
    
    def get_supporting_triples(self, chunk_ids: List[str]) -> List[Dict]:
        if not chunk_ids:
            return []

        cypher = """
        MATCH (c:Chunk)
        WHERE c.id IN $chunk_ids
        OPTIONAL MATCH (e1)-[r]->(e2)
        WHERE (e1)-[:MENTIONED_IN]->(c) AND (e2)-[:MENTIONED_IN]->(c)
        RETURN DISTINCT e1.name AS subject, type(r) AS relation, e2.name AS object
        LIMIT 30
        """
        with self.driver.session() as session:
            rows = session.run(cypher, chunk_ids=chunk_ids).data()
        return rows
    
        # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def invoke(self, query: str, entities: Optional[List[str]] = None) -> Dict[str, List[Dict]]:
        dense_ids, _ = self._dense_seeds(query)

        expansion_ids = (
            self._expand_parents(dense_ids, self.expand_parents)
            | self._expand_neighbors(dense_ids, self.expand_neighbors)
        )

        if entities is None:
            entities = self.entity_tagger(query)

        kg_ids = self.kg_probe_from_entities(entities)

        fused = self._rrf_fuse(
            [
                ("dense", dense_ids),
                ("expand", list(expansion_ids)),
                ("kg", kg_ids),
            ]
        )

        final_ids = [cid for cid, _ in fused[: self.max_chunks]]
        chunks = self._hydrate_chunks(final_ids)
        triples = self.get_supporting_triples(final_ids)

        return {
            "chunks": chunks,
            "triples": triples,
        }


In [10]:
class SmartRAGPipeline:
    def __init__(self, hybrid_retriever, triple_retriever, llm):
        self.hybrid = hybrid_retriever
        self.triples = triple_retriever
        self.llm = llm

    def run(self, query: str) -> str:
        entities = extract_entities(query)
        intent = classify_intent(query)

        if intent == "fact":
            # Mode A: KG-only
            triples = self.triples.get_triples(entities)
            triple_text = self._format_triples(triples)
            prompt = f"""Answer using the facts below. If the answer isn't found, say "I don't know."\n\n{triple_text}\n\nQ: {query}\nA:"""
            return self.llm(prompt)

        else:
            # Mode B: Chunks + triples
            chunks = self.hybrid.invoke(query, entities=entities)
            chunk_ids = [c["id"] for c in chunks]
            chunk_text = "\n\n---\n\n".join([c["text"] for c in chunks])
            triples = self.hybrid.get_supporting_triples(chunk_ids)
            triple_text = self._format_triples(triples)

            prompt = f"""Use the facts and context to answer. If not answerable, say "I don't know."\n\nFacts:\n{triple_text}\n\nContext:\n{chunk_text}\n\nQ: {query}\nA:"""
            return self.llm(prompt)

    def _format_triples(self, triples: List[Dict]) -> str:
        return "\n".join(f"- ({t['subject']}, {t['relation']}, {t['object']})" for t in triples)


In [16]:
retriever = HybridRetriever(
    driver=driver,
    embedder=embedder,
    index_name="chunk-index",
)

In [26]:
query = "List all iPhones released by Apple in 2022"

context = retriever.invoke(query)

In [27]:
context

{'chunks': [{'id': 'chunk_0035',
   'text': 'COVID-19\nThe COVID-19 pandemic has had, and continues to have, a significant impact around the world, prompting governments and businesses to take unprecedented\nmeasures, such as restrictions on travel and business operations, temporary closures of businesses, and quarantine and shelter-in-place orders. The COVID-19\npandemic has at times significantly curtailed global economic activity and caused significant volatility and disruption in global financial markets. The COVID-19\npandemic and the measures taken by many countries in response have affected and could in the future materially impact the Company’s business, results of\noperations and financial condition.\nCertain of the Company’s outsourcing partners, component suppliers and logistical service providers have experienced disruptions during the COVID-19\npandemic, resulting in supply shortages. Similar disruptions could occur in the future.\nProducts and Services Performance\nThe fo

In [28]:
context['chunks']

[{'id': 'chunk_0035',
  'text': 'COVID-19\nThe COVID-19 pandemic has had, and continues to have, a significant impact around the world, prompting governments and businesses to take unprecedented\nmeasures, such as restrictions on travel and business operations, temporary closures of businesses, and quarantine and shelter-in-place orders. The COVID-19\npandemic has at times significantly curtailed global economic activity and caused significant volatility and disruption in global financial markets. The COVID-19\npandemic and the measures taken by many countries in response have affected and could in the future materially impact the Company’s business, results of\noperations and financial condition.\nCertain of the Company’s outsourcing partners, component suppliers and logistical service providers have experienced disruptions during the COVID-19\npandemic, resulting in supply shortages. Similar disruptions could occur in the future.\nProducts and Services Performance\nThe following tabl

In [29]:
context['triples']

[{'subject': 'FinancialItem',
  'relation': 'RELATEDTOPRODUCT',
  'object': 'Product'},
 {'subject': 'FinancialItem',
  'relation': 'COVERSPERIOD',
  'object': 'DatePeriod'},
 {'subject': 'Product', 'relation': 'HASLINEITEM', 'object': 'FinancialItem'},
 {'subject': 'Company',
  'relation': 'HASSTATEMENT',
  'object': 'FinancialReport'},
 {'subject': 'FinancialReport',
  'relation': 'HASLINEITEM',
  'object': 'FinancialItem'},
 {'subject': 'FinancialReport',
  'relation': 'COVERSPERIOD',
  'object': 'DatePeriod'},
 {'subject': 'Apple Inc.', 'relation': 'HASLINEITEM', 'object': 'Q3 2022'},
 {'subject': 'Apple Inc.',
  'relation': 'HASLINEITEM',
  'object': 'Dividends Paid'},
 {'subject': 'Apple Inc.',
  'relation': 'HASLINEITEM',
  'object': 'Stock Repurchase'},
 {'subject': 'Apple Inc.',
  'relation': 'HASSTATEMENT',
  'object': 'Q3 2022 Form 10-Q'},
 {'subject': 'Apple Inc.',
  'relation': 'REPORTSON',
  'object': 'Q3 2022 Form 10-Q'},
 {'subject': 'Apple Inc.',
  'relation': 'RELATED

#### RAG PIPELINE

In [None]:
def format_triples(triples: List[Dict]) -> str:
    return "\n".join(f"- ({t['subject']}, {t['relation']}, {t['object']})" for t in triples)

In [None]:
query = "How was the apple sales in 2022?"
entities = extract_entities(query)
context = retriever.invoke(query, entities)

chunk_text = "\n\n---\n\n".join([c["text"] for c in context['chunks']])
chunk_ids = [c["id"] for c in context['chunks']]
triples = context['triples']
triple_text = format_triples(triples)

In [None]:
prompt = f"""Use the facts and context to answer. If not answerable, say "I don't know."\n\nFacts:\n{triple_text}\n\nContext:\n{chunk_text}\n\nQ: {query}\nA:"""
print(prompt)

Use the facts and context to answer. If not answerable, say "I don't know."

Facts:
- (Apple, HASSTATEMENT, Net Sales)
- (Net Sales, HASLINEITEM, Services)
- (Net Sales, HASLINEITEM, Wearables, Home and Accessories)
- (Net Sales, HASLINEITEM, iPad)
- (Net Sales, HASLINEITEM, Mac)
- (Net Sales, HASLINEITEM, iPhone)
- (Total net sales, HASVALUE, 82,959)
- (Apple Inc., HASSTATEMENT, Q3 2022 Form 10-Q)
- (Q3 2022 Form 10-Q, HASLINEITEM, Revenue Recognized)
- (Q3 2022 Form 10-Q, HASLINEITEM, Deferred Revenue)
- (FinancialItem, RELATEDTOSERVICE, Service)
- (FinancialItem, RELATEDTOPRODUCT, Product)
- (Product, HASLINEITEM, FinancialItem)
- (FinancialReport, HASSTATEMENT, FinancialStatement)
- (FinancialReport, HASLINEITEM, FinancialItem)
- (iPhone, RELATEDTOSERVICE, Services)
- (iPhone, HASVALUE, 28,669)
- (Mac, HASVALUE, 26,012)
- (Wearables, Home and Accessories, HASVALUE, 22,118)
- (Services, HASVALUE, 19,604)
- (Apple Inc., HASLINEITEM, Services)
- (Apple Inc., RELATEDTOSERVICE, Services

In [None]:
response = llm(prompt)
print(response)

content="In 2022, Apple's total net sales for the third quarter amounted to $82,959 million, showing a 2% increase compared to $81,434 million in the same quarter of the previous year. The iPhone generated the highest sales at $28,669 million, followed by the Mac at $26,012 million, services at $19,604 million, and wearables, home and accessories at $22,118 million. However, both iPad and wearables, home and accessories experienced decreased net sales compared to the previous year. Overall, services sales increased by 12%, indicating strong growth in that segment." additional_kwargs={'refusal': None} response_metadata={'token_usage': {'completion_tokens': 129, 'prompt_tokens': 3306, 'total_tokens': 3435, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_dbaca60df0', '

#### Problems

-  
- 

-

## Retriever = Approach - 2

In [None]:
from typing import List, Dict, Optional, Any
import numpy as np
from neo4j import GraphDatabase
from langchain.embeddings.base import Embeddings
from openai import OpenAI
import json
from enum import Enum, auto

class QueryType(Enum):
    ENTITY_FOCUSED = auto()  # "Who is Tim Cook?"
    RELATIONSHIP = auto()    # "What products did Apple release in 2023?"
    COMPARISON = auto()      # "Compare iPhone 13 and iPhone 14"
    CONTEXTUAL = auto()      # "Explain how Apple's supply chain works"
    TEMPORAL = auto()        # "What happened to Apple stock after the iPhone 15 launch?"

class EnhancedGraphRetriever:
    """
    An adaptive retriever that leverages both graph structure and semantic similarity,
    with dynamic weighting based on query understanding.
    """
    
    def __init__(
        self,
        driver: GraphDatabase.driver,
        embedder: Embeddings,
        llm_client: OpenAI,
        *,
        index_name: str,
        max_chunks: int = 12,
        max_triples: int = 20,
    ):
        self.driver = driver
        self.embedder = embedder
        self.llm_client = llm_client
        self.index_name = index_name
        self.max_chunks = max_chunks
        self.max_triples = max_triples
    
    def _classify_query(self, query: str) -> Dict[str, Any]:
        """Analyze query to determine type, entities, and relevant properties"""
        tools = [{
            "type": "function",
            "function": {
                "name": "analyze_query",
                "description": "Analyzes a search query for RAG retrieval",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query_type": {
                            "type": "string",
                            "enum": [qt.name for qt in QueryType],
                            "description": "The type of query"
                        },
                        "entities": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Named entities in the query"
                        },
                        "relationships": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Types of relationships to look for"
                        },
                        "time_focus": {
                            "type": "string",
                            "description": "Any temporal focus in the query (e.g., '2023', 'recent', 'after launch')"
                        },
                        "attributes": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Key attributes or properties mentioned"
                        }
                    },
                    "required": ["query_type", "entities"]
                }
            }
        }]

        response = self.llm_client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "Analyze the search query to optimize knowledge graph retrieval."},
                {"role": "user", "content": query}
            ],
            tools=tools,
            tool_choice={"type": "function", "function": {"name": "analyze_query"}}
        )

        args = response.choices[0].message.tool_calls[0].function.arguments
        return json.loads(args)
    
    def _retrieve_entity_context(self, entities: List[str], limit: int = 10) -> List[Dict]:
        """Get chunks that provide context about specified entities"""
        cypher = """
        MATCH (e)
        WHERE toLower(e.name) IN $entities
        MATCH (e)-[:MENTIONED_IN]->(c:Chunk)
        WITH c, count(DISTINCT e) AS entity_matches
        ORDER BY entity_matches DESC
        LIMIT $limit
        RETURN c.id AS id, c.text AS text, entity_matches
        """
        
        with self.driver.session() as session:
            rows = session.run(
                cypher,
                entities=[e.lower() for e in entities],
                limit=limit
            ).data()
        
        return rows
    
    def _retrieve_relationship_context(
        self, 
        entities: List[str], 
        relationships: List[str] = None,
        limit: int = 10
    ) -> List[Dict]:
        """Get chunks describing relationships between entities"""
        
        # Base query when no specific relationships are provided
        if not relationships or len(relationships) == 0:
            cypher = """
            MATCH (e1)-[r]->(e2)
            WHERE toLower(e1.name) IN $entities OR toLower(e2.name) IN $entities
            MATCH (e1)-[:MENTIONED_IN]->(c:Chunk)<-[:MENTIONED_IN]-(e2)
            RETURN DISTINCT c.id AS id, c.text AS text, 
                   count(DISTINCT r) AS rel_count
            ORDER BY rel_count DESC
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "limit": limit
            }
        else:
            # Query with specific relationship types
            cypher = """
            MATCH (e1)-[r]->(e2)
            WHERE (toLower(e1.name) IN $entities OR toLower(e2.name) IN $entities)
                  AND type(r) IN $relationships
            MATCH (e1)-[:MENTIONED_IN]->(c:Chunk)<-[:MENTIONED_IN]-(e2)
            RETURN DISTINCT c.id AS id, c.text AS text, 
                   count(DISTINCT r) AS rel_count
            ORDER BY rel_count DESC
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "relationships": [r.upper().replace(' ', '_') for r in relationships],
                "limit": limit
            }
            
        with self.driver.session() as session:
            rows = session.run(cypher, **params).data()
            
        return rows
    
    def _retrieve_temporal_context(
        self, 
        entities: List[str],
        time_focus: str,
        limit: int = 10
    ) -> List[Dict]:
        """Get chunks with temporal focus"""
        # First try to find chunks with explicit time references
        cypher = """
        MATCH (e)
        WHERE toLower(e.name) IN $entities
        MATCH (e)-[:MENTIONED_IN]->(c:Chunk)
        WHERE c.text CONTAINS $time_focus
        RETURN c.id AS id, c.text AS text, 1.0 AS score
        LIMIT $limit
        """
        
        with self.driver.session() as session:
            rows = session.run(
                cypher,
                entities=[e.lower() for e in entities],
                time_focus=time_focus,
                limit=limit
            ).data()
            
        return rows
    
    def _retrieve_dense_vectors(self, query: str, limit: int = 10) -> List[Dict]:
        """Semantic search using embeddings"""
        embedding = self.embedder.embed_query(query)
        cypher = """
        CALL db.index.vector.queryNodes($index_name, $k, $embed) 
        YIELD node, score 
        RETURN node.id AS id, node.text AS text, score
        ORDER BY score DESC
        """
        
        with self.driver.session() as session:
            rows = session.run(
                cypher,
                index_name=self.index_name,
                k=limit,
                embed=embedding
            ).data()
            
        return rows
    
    def _retrieve_comparison_context(
        self,
        entities: List[str],
        attributes: List[str] = None,
        limit: int = 10
    ) -> List[Dict]:
        """Get chunks comparing entities or discussing specific attributes"""
        # Base query when no attributes specified
        if not attributes or len(attributes) == 0:
            cypher = """
            MATCH (e1)-[:MENTIONED_IN]->(c:Chunk)<-[:MENTIONED_IN]-(e2)
            WHERE toLower(e1.name) IN $entities AND toLower(e2.name) IN $entities 
                  AND e1 <> e2
            RETURN c.id AS id, c.text AS text, 
                   count(DISTINCT e1) + count(DISTINCT e2) AS entity_count
            ORDER BY entity_count DESC
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "limit": limit
            }
        else:
            # Try to find specific attribute mentions
            attribute_clauses = []
            for attr in attributes:
                attribute_clauses.append(f"c.text CONTAINS '{attr}'")
            
            attr_filter = " OR ".join(attribute_clauses)
            
            cypher = f"""
            MATCH (e1)-[:MENTIONED_IN]->(c:Chunk)<-[:MENTIONED_IN]-(e2)
            WHERE toLower(e1.name) IN $entities AND toLower(e2.name) IN $entities 
                  AND e1 <> e2
                  AND ({attr_filter})
            RETURN c.id AS id, c.text AS text, 
                   count(DISTINCT e1) + count(DISTINCT e2) AS entity_count
            ORDER BY entity_count DESC
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "limit": limit
            }
            
        with self.driver.session() as session:
            rows = session.run(cypher, **params).data()
            
        return rows
    
    def get_relevant_triples(self, entities: List[str], relationships: List[str] = None) -> List[Dict]:
        """Retrieve knowledge graph triples directly"""
        if not relationships or len(relationships) == 0:
            cypher = """
            MATCH (s)-[r]->(o)
            WHERE toLower(s.name) IN $entities OR toLower(o.name) IN $entities
            RETURN DISTINCT s.name AS subject, type(r) AS relation, o.name AS object
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "limit": self.max_triples
            }
        else:
            cypher = """
            MATCH (s)-[r]->(o)
            WHERE (toLower(s.name) IN $entities OR toLower(o.name) IN $entities)
                  AND type(r) IN $relationships
            RETURN DISTINCT s.name AS subject, type(r) AS relation, o.name AS object
            LIMIT $limit
            """
            params = {
                "entities": [e.lower() for e in entities],
                "relationships": [r.upper().replace(' ', '_') for r in relationships],
                "limit": self.max_triples
            }
            
        with self.driver.session() as session:
            rows = session.run(cypher, **params).data()
            
        return rows
    
    def _merge_and_deduplicate(self, chunk_sets: List[List[Dict]]) -> List[Dict]:
        """Merge multiple chunk sets with scoring and deduplication"""
        id_to_chunk = {}
        id_to_score = {}
        
        # Process each chunk set with different weights based on position
        for i, chunk_set in enumerate(chunk_sets):
            # Weight earlier sets higher (they're more precisely matched to query type)
            weight = 1.0 / (i + 1)
            
            for chunk in chunk_set:
                chunk_id = chunk["id"]
                
                # Store the chunk if we haven't seen it
                if chunk_id not in id_to_chunk:
                    id_to_chunk[chunk_id] = chunk
                    id_to_score[chunk_id] = 0
                
                # Add to the score (different chunk sets may score differently)
                if "score" in chunk:
                    id_to_score[chunk_id] += chunk["score"] * weight
                elif "entity_matches" in chunk:
                    id_to_score[chunk_id] += chunk["entity_matches"] * weight
                elif "rel_count" in chunk:
                    id_to_score[chunk_id] += chunk["rel_count"] * weight
                elif "entity_count" in chunk:
                    id_to_score[chunk_id] += chunk["entity_count"] * weight
                else:
                    id_to_score[chunk_id] += weight
        
        # Sort by score and return top chunks
        sorted_ids = sorted(id_to_score.keys(), key=lambda x: -id_to_score[x])
        return [id_to_chunk[chunk_id] for chunk_id in sorted_ids[:self.max_chunks]]
    
    def invoke(self, query: str) -> Dict[str, Any]:
        """Main retrieval method adapting to query type"""
        # Analyze the query
        query_analysis = self._classify_query(query)
        query_type = QueryType[query_analysis["query_type"]]
        entities = query_analysis["entities"]
        
        # Initialize chunk lists based on query type
        chunks_from_specialized = []
        chunks_from_dense = self._retrieve_dense_vectors(query)
        
        # Apply specialized retrievers based on query type
        if query_type == QueryType.ENTITY_FOCUSED:
            chunks_from_specialized = self._retrieve_entity_context(entities)
            
        elif query_type == QueryType.RELATIONSHIP:
            relationships = query_analysis.get("relationships", [])
            chunks_from_specialized = self._retrieve_relationship_context(
                entities, relationships
            )
            
        elif query_type == QueryType.COMPARISON:
            attributes = query_analysis.get("attributes", [])
            chunks_from_specialized = self._retrieve_comparison_context(
                entities, attributes
            )
            
        elif query_type == QueryType.TEMPORAL:
            time_focus = query_analysis.get("time_focus", "")
            if time_focus:
                chunks_from_specialized = self._retrieve_temporal_context(
                    entities, time_focus
                )
            
        # For contextual queries, we rely more on dense retrieval
        # But for all types, we merge specialized and dense results
        merged_chunks = self._merge_and_deduplicate([
            chunks_from_specialized,
            chunks_from_dense
        ])
        
        # Get relevant triples
        relationships = query_analysis.get("relationships", [])
        triples = self.get_relevant_triples(entities, relationships)
        
        # Return both chunks and triples
        return {
            "chunks": merged_chunks,
            "triples": triples,
            "query_analysis": query_analysis
        }

In [None]:
# Setup
llm_client = OpenAI(api_key=api_key)

# Create retriever
retriever = EnhancedGraphRetriever(
    driver=driver,
    embedder=embedder,
    llm_client=llm_client,
    index_name="chunk-index"
)

# Example queries
result1 = retriever.invoke("Who is Tim Cook?")
result2 = retriever.invoke("What products did Apple release in 2023?")
result3 = retriever.invoke("Compare iPhone 13 and iPhone 14 battery life")
result4 = retriever.invoke("How does Apple's supply chain work?")

In [None]:
result1.keys()

dict_keys(['chunks', 'triples', 'query_analysis'])

In [None]:
print(result1['chunks'])
print(result1['triples'])
print(result1['query_analysis'])

[{'id': 'chunk_0057', 'text': 'Exhibit 31.1\nCERTIFICATION\nI, Timothy D. Cook, certify that:\n1.\nI have reviewed this quarterly report on Form 10-Q of Apple Inc.;\n2.\nBased on my knowledge, this report does not contain any untrue statement of a material fact or omit to state a material fact necessary to make the\nstatements made, in light of the circumstances under which such statements were made, not misleading with respect to the period covered by this report;\n3.\nBased on my knowledge, the financial statements, and other financial information included in this report, fairly present in all material respects the financial\ncondition, results of operations and cash flows of the Registrant as of, and for, the periods presented in this report;\n4.\nThe Registrant’s other certifying officer(s) and I are responsible for establishing and maintaining disclosure controls and procedures (as defined in Exchange\nAct Rules 13a-15(e) and 15d-15(e)) and internal control over financial reportin

In [None]:
print(result2['chunks'])
print(result2['triples'])
print(result2['query_analysis'])

[{'id': 'chunk_0015', 'text': 'Note 2 – Revenue\nNet sales disaggregated by significant products and services for the three- and nine-month periods ended June 25, 2022 and June\xa026, 2021 were as follows (in\nmillions):\nThree Months Ended\nNine Months Ended\nJune 25,\n2022\nJune 26,\n2021\nJune 25,\n2022\nJune 26,\n2021\niPhone\n$\n40,665\xa0\n$\n39,570\xa0\n$\n162,863\xa0\n$\n153,105\xa0\nMac  \n7,382\xa0\n8,235\xa0\n28,669\xa0\n26,012\xa0\niPad\n7,224\xa0\n7,368\xa0\n22,118\xa0\n23,610\xa0\nWearables, Home and Accessories \n8,084\xa0\n8,775\xa0\n31,591\xa0\n29,582\xa0\nServices \n19,604\xa0\n17,486\xa0\n58,941\xa0\n50,148\xa0\nTotal net sales \n$\n82,959\xa0\n$\n81,434\xa0\n$\n304,182\xa0\n$\n282,457\xa0\n(1)\nProducts net sales include amortization of the deferred value of unspecified software upgrade rights, which are bundled in the sales price of the respective\nproduct.\n(2)\nWearables, Home and Accessories net sales include sales of AirPods , Apple TV , Apple Watch , Beats  pr