In [1]:
from dotenv import load_dotenv
from neo4j import GraphDatabase
import os
## Load environment variables from a local .env file into os.environ so the
## os.getenv(...) calls below can pick up the OpenAI key and Neo4j credentials.
load_dotenv()

# Secrets and connection settings read from the environment (None when unset).
api_key, uri, user, password = map(
    os.getenv,
    ("OPENAI_API_KEY", "NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"),
)
# Neo4j basic-auth pair, reused when the driver is created.
AUTH = (user, password)

In [3]:
from neo4j import GraphDatabase


# --- prerequisites: one Neo4j driver shared by every retriever below -------
driver = GraphDatabase.driver(uri, auth=AUTH)


## Retriever — Approach 1

In [35]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# Embeds queries for the chunk vector index; the chat model is the answering LLM.
embedder = OpenAIEmbeddings(
    model="text-embedding-3-large",
    api_key=api_key,
)
llm = ChatOpenAI(
    openai_api_key=api_key,
    model="gpt-4o-mini",
)

In [5]:
from openai import OpenAI
import json

# Raw OpenAI client for the function-calling helpers; the key was read from the environment above.
client = OpenAI(api_key=api_key)

# -------------------- Entity Extraction -------------------- #
def extract_entities(query: str) -> list[str]:
    tools = [
        {
            "type": "function",
            "function": {
                "name": "extract_entities",
                "description": "Extracts all proper nouns or named entities from a query.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "entities": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "List of named entities"
                        }
                    },
                    "required": ["entities"]
                }
            }
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Extract named entities from the user question."},
            {"role": "user", "content": query}
        ],
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "extract_entities"}}
    )

    args = response.choices[0].message.tool_calls[0].function.arguments
    return json.loads(args)["entities"]


In [6]:
# -------------------- Query Intent Classification -------------------- #
# Allowed labels; also the enum advertised to the model below.
_INTENTS = ("fact", "context")

# Function-calling schema forced on the model; built once instead of per call.
_CLASSIFY_INTENT_TOOL = {
    "type": "function",
    "function": {
        "name": "classify_intent",
        "description": "Classifies the query intent as fact or context.",
        "parameters": {
            "type": "object",
            "properties": {
                "intent": {
                    "type": "string",
                    "enum": ["fact", "context"],
                    "description": "Is it a fact query (asks for a list or factual triple) or a context query (asks for explanation)?"
                }
            },
            "required": ["intent"]
        }
    }
}


def classify_intent(query: str, model: str = "gpt-4o") -> str:
    """Classify ``query`` as a ``"fact"`` or a ``"context"`` question.

    The model is forced to call the ``classify_intent`` tool, whose JSON
    arguments have the form ``{"intent": "fact" | "context"}``.

    Parameters
    ----------
    query : str
        Natural-language user question.
    model : str
        Chat-completions model name (default ``"gpt-4o"``).

    Returns
    -------
    str
        ``"fact"`` or ``"context"``. The label is trimmed and lower-cased.
        Any missing tool call or unexpected label falls back to ``"context"``,
        the broader chunk-plus-KG retrieval mode.

    Raises
    ------
    json.JSONDecodeError
        If the tool-call arguments are not valid JSON.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "Classify the user query as 'fact' or 'context'."},
            {"role": "user", "content": query}
        ],
        tools=[_CLASSIFY_INTENT_TOOL],
        tool_choice={"type": "function", "function": {"name": "classify_intent"}}
    )

    # tool_calls is None when the model answers in plain text instead.
    tool_calls = response.choices[0].message.tool_calls or []
    if not tool_calls:
        return "context"

    payload = json.loads(tool_calls[0].function.arguments)
    label = payload.get("intent") if isinstance(payload, dict) else None
    intent = label.strip().lower() if isinstance(label, str) else ""
    return intent if intent in _INTENTS else "context"


In [7]:
# Quick sanity check of both LLM helpers on a representative "fact" question.
query = "List all iPhones released by Apple in 2022"

entities, intent = extract_entities(query), classify_intent(query)

for label, value in (("Entities:", entities), ("Intent:", intent)):
    print(label, value)

Entities: ['iPhones', 'Apple', '2022']
Intent: fact


In [8]:
from typing import List, Dict, Tuple, Callable, Optional, Sequence

class TripleRetriever:
    """Fetch raw (subject, relation, object) triples that touch given entities.

    A relationship matches when its start or end node has a ``name`` equal,
    case-insensitively, to one of the (trimmed) query entities.
    """

    def __init__(self, driver: "GraphDatabase.driver"):
        # Driver object as returned by GraphDatabase.driver(...).
        self.driver = driver

    def get_triples(self, query_entities: List[str], limit: int = 20) -> List[Dict]:
        """Return up to ``limit`` triples whose subject or object is a query entity.

        Parameters
        ----------
        query_entities :
            Entity surface strings (e.g. from ``extract_entities``). They are
            trimmed, lower-cased and de-duplicated, and blanks are dropped.
        limit :
            Maximum number of rows to return.

        Returns
        -------
        List of ``{"subject": ..., "relation": ..., "object": ...}`` dicts.
        Empty when no usable entity remains or ``limit`` <= 0; Neo4j is not
        queried in that case.
        """
        # Normalise once so " Apple " and "apple" become a single parameter
        # value and empty strings never reach the database.
        entities = list(dict.fromkeys(
            e.strip().lower()
            for e in (query_entities or [])
            if isinstance(e, str) and e.strip()
        ))
        if not entities or limit <= 0:
            return []

        cypher = """
        MATCH (s)-[r]->(o)
        WHERE toLower(s.name) IN $entities OR toLower(o.name) IN $entities
        RETURN s.name AS subject, type(r) AS relation, o.name AS object
        LIMIT $limit
        """
        with self.driver.session() as session:
            rows = session.run(
                cypher,
                entities=entities,
                limit=limit,
            ).data()
        return rows  # List of {"subject": ..., "relation": ..., "object": ...}


In [9]:
# hybrid_retriever.py
from __future__ import annotations
from typing import List, Dict, Tuple, Callable, Optional, Sequence
from collections import defaultdict

import numpy as np
from neo4j import GraphDatabase
from langchain.embeddings.base import Embeddings


class HybridRetriever:
    """
    A *single-entry-point* retriever that combines

    1. **Dense semantic search** over `(:Chunk)` vectors stored in Neo4j
    2. **Structural expansion** (PARENT_OF, NEXT_CHUNK) to keep narrative context
    3. **Knowledge-graph probing** via light-weight NER → `(entity)-[:MENTIONED_IN]->(:Chunk)`
       (any node whose `name` matches a query entity)
    4. **Reciprocal-Rank Fusion** to merge the three candidate sets

    `invoke` returns ``{"chunks": [...], "triples": [...]}``. ``chunks`` is an
    *ordered* list of ``{"id": ..., "text": ...}`` dicts, ready to be
    concatenated into your RAG prompt template. ``triples`` holds
    (subject, relation, object) rows between entities mentioned in those chunks.

    Parameters
    ----------
    driver :
        An active Neo4j driver (as returned by `neo4j.GraphDatabase.driver`).
    embedder :
        An object that implements `embed_query(text) -> List[float]`
        (e.g. `langchain_openai.OpenAIEmbeddings`).
    index_name :
        Name of the Neo4j vector index created by `create_vector_index`.
    entity_tagger :
        Callable that returns a list of surface strings found in a query.
        Default is a mini-regex fallback that picks runs of two or more
        capitalised words. In production you will probably inject spaCy,
        HuggingFace, or an OpenAI function-calling wrapper.
    top_k_dense :
        How many dense seeds to return.
    expand_parents / expand_neighbors :
        Maximum number of hops to follow from each dense seed. PARENT_OF is
        followed against its direction (towards ancestors) and NEXT_CHUNK
        along it (following chunks only). 0 disables the respective expansion.
    kg_chunk_limit :
        KG budget per distinct query entity. The probe is capped at
        ``kg_chunk_limit * number_of_entities`` chunks *in total*, so a very
        frequent entity can use up the whole budget.
    max_chunks :
        Final cap on returned chunks after fusion – keeps your context window bounded.
    """

    def __init__(
        self,
        driver: "GraphDatabase.driver",
        embedder: "Embeddings",
        *,
        index_name: str,
        entity_tagger: Optional[Callable[[str], Sequence[str]]] = None,
        top_k_dense: int = 5,
        expand_parents: int = 1,
        expand_neighbors: int = 1,
        kg_chunk_limit: int = 3,
        max_chunks: int = 12,
    ) -> None:
        self.driver = driver
        self.embedder = embedder
        self.index_name = index_name
        self.top_k_dense = top_k_dense
        self.expand_parents = expand_parents
        self.expand_neighbors = expand_neighbors
        self.kg_chunk_limit = kg_chunk_limit
        self.max_chunks = max_chunks
        self.entity_tagger = entity_tagger or self._naive_capitalised_np

    # --------------------------------------------------------------------- #
    # 1) Dense search (vector index)                                         #
    # --------------------------------------------------------------------- #
    def _dense_seeds(self, query: str) -> Tuple[List[str], List[float]]:
        """Embed ``query`` and return (ids, scores) of the nearest chunks.

        Rows come back in descending score order. Rows whose node has no
        ``id`` are dropped, so fusion never has to sort ``None``.
        """
        embedding = self.embedder.embed_query(query)
        cypher = (
            "CALL db.index.vector.queryNodes($index_name, $k, $embed) "
            "YIELD node, score "
            "RETURN node.id AS id, score "
            "ORDER BY score DESC"
        )
        with self.driver.session() as s:
            rows = s.run(
                cypher,
                index_name=self.index_name,
                k=self.top_k_dense,
                embed=embedding,
            ).data()
        rows = [r for r in rows if r["id"] is not None]
        ids = [r["id"] for r in rows]
        scores = [r["score"] for r in rows]
        return ids, scores

    # --------------------------------------------------------------------- #
    # 2) Context expansion via PARENT_OF / NEXT_CHUNK                       #
    # --------------------------------------------------------------------- #
    def _expand_parents(self, chunk_ids: List[str], hops: int) -> set[str]:
        """Ids of nodes reaching a seed chunk through 1..hops PARENT_OF edges.

        The ancestors are not label-filtered, so they may not be ``:Chunk`` nodes.
        """
        hops = int(hops)
        if hops <= 0 or not chunk_ids:
            return set()
        # Variable-length bounds cannot be Cypher parameters, so the
        # int-coerced hop count is interpolated into the query text.
        cypher = f"""
            MATCH (c:Chunk)<-[:PARENT_OF*1..{hops}]-(p)
            WHERE c.id IN $ids
            RETURN DISTINCT p.id AS id
        """
        with self.driver.session() as s:
            rows = s.run(cypher, ids=chunk_ids).data()
        return {r["id"] for r in rows if r["id"] is not None}

    def _expand_neighbors(self, chunk_ids: List[str], hops: int) -> set[str]:
        """Ids of chunks 1..hops NEXT_CHUNK steps *after* a seed (forward only)."""
        hops = int(hops)
        if hops <= 0 or not chunk_ids:
            return set()
        # Same reason as above: variable-length bounds must be literal.
        cypher = f"""
            MATCH (c1:Chunk)-[:NEXT_CHUNK*1..{hops}]->(c2:Chunk)
            WHERE c1.id IN $ids
            RETURN DISTINCT c2.id AS id
        """
        with self.driver.session() as s:
            rows = s.run(cypher, ids=chunk_ids).data()
        return {r["id"] for r in rows if r["id"] is not None}

    # --------------------------------------------------------------------- #
    # 3) Knowledge-graph probe                                              #
    # --------------------------------------------------------------------- #
    def kg_probe_from_entities(self, entities: Sequence[str]) -> List[str]:
        """Ids of chunks linked by MENTIONED_IN from nodes named like ``entities``.

        Entities are trimmed, lower-cased and de-duplicated, and blanks are
        dropped. The result is capped at ``kg_chunk_limit`` times the number of
        distinct entities.
        """
        ents = list(dict.fromkeys(
            e.strip().lower()
            for e in (entities or [])
            if isinstance(e, str) and e.strip()
        ))
        limit = self.kg_chunk_limit * len(ents)
        if not ents or limit <= 0:
            return []
        cypher = (
            "MATCH (e)-[:MENTIONED_IN]->(c:Chunk) "
            "WHERE toLower(e.name) IN $ents "
            "RETURN DISTINCT c.id AS id "
            "LIMIT $lim"
        )
        with self.driver.session() as s:
            rows = s.run(cypher, ents=ents, lim=limit).data()
        return [r["id"] for r in rows if r["id"] is not None]

    # --------------------------------------------------------------------- #
    # 4) Reciprocal-Rank-Fusion                                             #
    # --------------------------------------------------------------------- #
    def _rrf_fuse(self, runs: List[Tuple[str, List[str]]], k: int = 60) -> List[Tuple[str, float]]:
        """
        Each `runs[i]` is (label, [id0, id1, ...]) in rank order.

        RRF score of an id = sum over runs of 1 / (k + rank), with 1-based rank.
        Higher is better.
        """
        scores = defaultdict(float)
        for _, ids in runs:
            for rank, cid in enumerate(ids):
                scores[cid] += 1.0 / (k + rank + 1)
        # Sort: high score first, deterministic tie-break by id
        return sorted(scores.items(), key=lambda x: (-x[1], x[0]))

    # --------------------------------------------------------------------- #
    # 5) Fetch the chunk text                                               #
    # --------------------------------------------------------------------- #
    def _hydrate_chunks(self, ids: List[str]) -> List[Dict]:
        """Return ``{"id", "text"}`` for each id that is a ``:Chunk``, in ``ids`` order."""
        if not ids:
            return []
        cypher = (
            "MATCH (c:Chunk) WHERE c.id IN $ids "
            "RETURN c.id AS id, c.text AS text"
        )
        with self.driver.session() as s:
            rows = s.run(cypher, ids=ids).data()
        lookup = {r["id"]: r["text"] for r in rows}
        return [{"id": cid, "text": lookup[cid]} for cid in ids if cid in lookup]

    # --------------------------------------------------------------------- #
    # Fallback entity tagger                                                #
    # --------------------------------------------------------------------- #
    @staticmethod
    def _naive_capitalised_np(text: str) -> List[str]:
        """Very small regex-based heuristic: sequences of ≥2 capitalised words."""
        import re

        patt = re.compile(r"\b([A-Z][\w\-]*(?:\s+[A-Z][\w\-]*)+)\b")
        return patt.findall(text)

    def get_supporting_triples(self, chunk_ids: List[str], limit: int = 30) -> List[Dict]:
        """Distinct triples between entities that are both MENTIONED_IN a given chunk.

        Plain MATCH clauses are used rather than OPTIONAL MATCH. With OPTIONAL
        MATCH, a chunk without such a pair yields an all-null row, and the
        query scans every relationship in the graph.

        Returns at most ``limit`` ``{"subject", "relation", "object"}`` dicts.
        """
        if not chunk_ids or limit <= 0:
            return []

        cypher = """
        MATCH (c:Chunk)
        WHERE c.id IN $chunk_ids
        MATCH (e1)-[:MENTIONED_IN]->(c)
        MATCH (e2)-[:MENTIONED_IN]->(c)
        MATCH (e1)-[r]->(e2)
        RETURN DISTINCT e1.name AS subject, type(r) AS relation, e2.name AS object
        LIMIT $limit
        """
        with self.driver.session() as session:
            rows = session.run(cypher, chunk_ids=chunk_ids, limit=limit).data()
        return rows

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def invoke(self, query: str, entities: Optional[List[str]] = None) -> Dict[str, List[Dict]]:
        """Retrieve fused context for ``query``.

        ``entities`` overrides the entity tagger when given. Returns
        ``{"chunks": [...], "triples": [...]}`` with at most ``max_chunks`` chunks.
        """
        dense_ids, _ = self._dense_seeds(query)

        expansion_ids = (
            self._expand_parents(dense_ids, self.expand_parents)
            | self._expand_neighbors(dense_ids, self.expand_neighbors)
        )

        if entities is None:
            entities = self.entity_tagger(query)

        kg_ids = self.kg_probe_from_entities(entities)

        fused = self._rrf_fuse(
            [
                ("dense", dense_ids),
                # RRF is rank-based, and set iteration order of strings varies
                # between processes (hash randomisation). Sort for reproducible ranks.
                ("expand", sorted(expansion_ids)),
                ("kg", kg_ids),
            ]
        )

        # Hydrate the whole fused list *before* truncating. Ids that are not
        # :Chunk nodes (e.g. unlabeled PARENT_OF ancestors) are dropped here
        # and must not use up max_chunks slots.
        chunks = self._hydrate_chunks([cid for cid, _ in fused])[: self.max_chunks]
        final_ids = [c["id"] for c in chunks]
        triples = self.get_supporting_triples(final_ids)

        return {
            "chunks": chunks,
            "triples": triples,
        }


In [10]:
class SmartRAGPipeline:
    """Route a question to KG-only or chunk-plus-KG answering based on its intent.

    Uses the module-level helpers ``extract_entities`` and ``classify_intent``.

    Parameters
    ----------
    hybrid_retriever :
        Object with ``invoke(query, entities=...)``. A dict result with
        ``"chunks"``/``"triples"`` keys (what ``HybridRetriever.invoke``
        returns) is used directly. A plain list of ``{"id", "text"}`` chunks is
        also accepted; its triples are then fetched with
        ``hybrid_retriever.get_supporting_triples``.
    triple_retriever :
        Object with ``get_triples(entities)`` (e.g. ``TripleRetriever``).
    llm :
        A LangChain-style model exposing ``invoke(prompt)``, or any plain
        callable ``llm(prompt)``. A reply with a ``.content`` attribute is
        unwrapped to its text.
    """

    def __init__(self, hybrid_retriever, triple_retriever, llm):
        self.hybrid = hybrid_retriever
        self.triples = triple_retriever
        self.llm = llm

    def run(self, query: str) -> str:
        """Answer ``query`` and return the LLM's reply as plain text."""
        entities = extract_entities(query)
        intent = classify_intent(query)

        if intent == "fact":
            # Mode A: KG-only
            triples = self.triples.get_triples(entities)
            triple_text = self._format_triples(triples)
            prompt = f"""Answer using the facts below. If the answer isn't found, say "I don't know."\n\n{triple_text}\n\nQ: {query}\nA:"""
            return self._ask_llm(prompt)

        # Mode B: Chunks + triples
        result = self.hybrid.invoke(query, entities=entities)
        if isinstance(result, dict):
            # HybridRetriever.invoke already computed the supporting triples.
            chunks = result.get("chunks") or []
            triples = result.get("triples")
        else:
            chunks, triples = list(result), None
        if triples is None:
            triples = self.hybrid.get_supporting_triples([c["id"] for c in chunks])

        chunk_text = "\n\n---\n\n".join(c["text"] for c in chunks if c.get("text"))
        triple_text = self._format_triples(triples)

        prompt = f"""Use the facts and context to answer. If not answerable, say "I don't know."\n\nFacts:\n{triple_text}\n\nContext:\n{chunk_text}\n\nQ: {query}\nA:"""
        return self._ask_llm(prompt)

    def _ask_llm(self, prompt: str) -> str:
        """Send ``prompt`` to the LLM and return the reply as a string."""
        # Calling a LangChain chat model directly (llm(prompt)) is deprecated and
        # yields a message object, so prefer .invoke() and unwrap .content.
        invoke = getattr(self.llm, "invoke", None)
        reply = invoke(prompt) if callable(invoke) else self.llm(prompt)
        content = getattr(reply, "content", reply)
        return content if isinstance(content, str) else str(content)

    def _format_triples(self, triples: List[Dict]) -> str:
        """Render triples as ``- (subject, relation, object)`` bullet lines."""
        # Skip rows without a relation (e.g. null rows from an OPTIONAL MATCH).
        return "\n".join(
            f"- ({t['subject']}, {t['relation']}, {t['object']})"
            for t in (triples or [])
            if t.get("relation") is not None
        )


In [16]:
# Hybrid retriever over the "chunk-index" vector index; every other knob keeps its default.
retriever = HybridRetriever(driver=driver, embedder=embedder, index_name="chunk-index")

In [26]:
query = "List all iPhones released by Apple in 2022"
# No explicit entities: the retriever derives them itself via its entity_tagger.
context = retriever.invoke(query, entities=None)

In [27]:
# Full retriever payload: a dict with "chunks" and "triples" keys.
context

{'chunks': [{'id': 'chunk_0035',
   'text': 'COVID-19\nThe COVID-19 pandemic has had, and continues to have, a significant impact around the world, prompting governments and businesses to take unprecedented\nmeasures, such as restrictions on travel and business operations, temporary closures of businesses, and quarantine and shelter-in-place orders. The COVID-19\npandemic has at times significantly curtailed global economic activity and caused significant volatility and disruption in global financial markets. The COVID-19\npandemic and the measures taken by many countries in response have affected and could in the future materially impact the Company’s business, results of\noperations and financial condition.\nCertain of the Company’s outsourcing partners, component suppliers and logistical service providers have experienced disruptions during the COVID-19\npandemic, resulting in supply shortages. Similar disruptions could occur in the future.\nProducts and Services Performance\nThe fo

In [28]:
# Ordered {"id", "text"} chunk dicts selected after rank fusion.
context["chunks"]

[{'id': 'chunk_0035',
  'text': 'COVID-19\nThe COVID-19 pandemic has had, and continues to have, a significant impact around the world, prompting governments and businesses to take unprecedented\nmeasures, such as restrictions on travel and business operations, temporary closures of businesses, and quarantine and shelter-in-place orders. The COVID-19\npandemic has at times significantly curtailed global economic activity and caused significant volatility and disruption in global financial markets. The COVID-19\npandemic and the measures taken by many countries in response have affected and could in the future materially impact the Company’s business, results of\noperations and financial condition.\nCertain of the Company’s outsourcing partners, component suppliers and logistical service providers have experienced disruptions during the COVID-19\npandemic, resulting in supply shortages. Similar disruptions could occur in the future.\nProducts and Services Performance\nThe following tabl

In [29]:
# Entity-to-entity triples found among the entities mentioned in those chunks.
context["triples"]

[{'subject': 'FinancialItem',
  'relation': 'RELATEDTOPRODUCT',
  'object': 'Product'},
 {'subject': 'FinancialItem',
  'relation': 'COVERSPERIOD',
  'object': 'DatePeriod'},
 {'subject': 'Product', 'relation': 'HASLINEITEM', 'object': 'FinancialItem'},
 {'subject': 'Company',
  'relation': 'HASSTATEMENT',
  'object': 'FinancialReport'},
 {'subject': 'FinancialReport',
  'relation': 'HASLINEITEM',
  'object': 'FinancialItem'},
 {'subject': 'FinancialReport',
  'relation': 'COVERSPERIOD',
  'object': 'DatePeriod'},
 {'subject': 'Apple Inc.', 'relation': 'HASLINEITEM', 'object': 'Q3 2022'},
 {'subject': 'Apple Inc.',
  'relation': 'HASLINEITEM',
  'object': 'Dividends Paid'},
 {'subject': 'Apple Inc.',
  'relation': 'HASLINEITEM',
  'object': 'Stock Repurchase'},
 {'subject': 'Apple Inc.',
  'relation': 'HASSTATEMENT',
  'object': 'Q3 2022 Form 10-Q'},
 {'subject': 'Apple Inc.',
  'relation': 'REPORTSON',
  'object': 'Q3 2022 Form 10-Q'},
 {'subject': 'Apple Inc.',
  'relation': 'RELATED

#### RAG PIPELINE