In [1]:
import nest_asyncio
nest_asyncio.apply()

from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex, StorageContext, PropertyGraphIndex
from llama_index.core.graph_stores import SimpleGraphStore

from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from IPython.display import Markdown, display

from llama_index.readers.docling import DoclingReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.node_parser import SentenceWindowNodeParser

import dotenv
dotenv.load_dotenv()

from llama_index.core.schema import TextNode

from llama_index.core.indices.property_graph import (
    SimpleLLMPathExtractor,
    SchemaLLMPathExtractor,
    DynamicLLMPathExtractor,
)

from llama_index.core.graph_stores import SimplePropertyGraphStore, PropertyGraphStore
from llama_index.core.graph_stores.types import LabelledPropertyGraph, EntityNode, Relation, ChunkNode
from uuid import uuid4
from llama_index.core.extractors import KeywordExtractor

In [2]:
# Demo schema: one COMPANY node, four VALUE metrics each paired with a short alias
# VALUE node, and two EQUATION nodes that relate the metrics to each other.
company = EntityNode(label="COMPANY", name="Applied Materials")
num_apples, apples, num_pears, pears, num_bananas, bananas, num_apricots, apricots = (
    EntityNode(label="VALUE", name=value_name)
    for value_name in (
        "number of apples", "apples",
        "number of pears", "pears",
        "number of bananas", "bananas",
        "number of apricots", "apricots",
    )
)
eq_apple_pear = EntityNode(
    label="EQUATION",
    name="number of bananas = number of apples - number of pears",
)
eq_banana_apricot = EntityNode(
    label="EQUATION",
    name="number of apricots = number of bananas",
)

# Upsert order is kept exactly as originally listed; downstream iteration over the
# store's nodes may depend on it.
entities = [
    company,
    apples, pears, bananas, apricots,
    num_apples, num_pears, num_bananas, num_apricots,
    eq_apple_pear, eq_banana_apricot,
]

# The company owns each long-form metric.
relations_kpi = [
    Relation(label="HAS_PROPERTY", source_id=company.id, target_id=metric.id)
    for metric in (num_apples, num_pears, num_bananas, num_apricots)
]

# Alias pairs are linked in both directions: short -> long, then long -> short.
relations_kpi += [
    Relation(label=alias_label, source_id=src.id, target_id=dst.id)
    for alias_label, short_node, long_node in (
        ("SAME_AS", bananas, num_bananas),
        ("SAME_AS", apples, num_apples),
        ("ALIAS_OF", pears, num_pears),
        ("ALIAS_OF", apricots, num_apricots),
    )
    for src, dst in ((short_node, long_node), (long_node, short_node))
]

# Equations: participating values point at the equation (HAS_EQUATION) and the
# equation points at the values it relates (DEPENDS_ON).
relations_kpi += [
    Relation(label=rel_label, source_id=src.id, target_id=dst.id)
    for rel_label, src, dst in (
        # number of apricots = number of bananas
        ("HAS_EQUATION", num_bananas, eq_banana_apricot),
        ("HAS_EQUATION", num_apricots, eq_banana_apricot),
        ("DEPENDS_ON", eq_banana_apricot, bananas),
        ("DEPENDS_ON", eq_banana_apricot, apricots),
        # number of bananas = number of apples - number of pears
        ("HAS_EQUATION", bananas, eq_apple_pear),
        ("HAS_EQUATION", apples, eq_apple_pear),
        ("HAS_EQUATION", pears, eq_apple_pear),
        ("DEPENDS_ON", eq_apple_pear, bananas),
        ("DEPENDS_ON", eq_apple_pear, apples),
        ("DEPENDS_ON", eq_apple_pear, pears),
    )
]

# Relation-label vocabularies; R_ALIAS is used by the alias lookups in later cells.
R_EQUATION = {"HAS_EQUATION", "COMPUTED_BY", "EQUALS"}
R_ALIAS = {"ALIAS_OF", "SAME_AS"}
R_DEPENDS = {"DEPENDS_ON", "USES", "INPUT"}

kgraph = SimplePropertyGraphStore()
kgraph.upsert_nodes(entities)
kgraph.upsert_relations(relations_kpi)

kgraph.show_jupyter_graph()

GraphWidget(layout=Layout(height='610px', width='100%'))

In [3]:
# Evidence sentences; each TextNode is stored in the graph store alongside the entities.
sentences = [
    TextNode(text=evidence)
    for evidence in (
        "the company Applied Materials (AMAT) has 100 apples",
        "the company Applied Materials (AMAT) has 10 pears",
    )
]

kgraph.upsert_llama_nodes(sentences)

kgraph.show_jupyter_graph()

GraphWidget(layout=Layout(height='630px', width='100%'))

In [4]:
# Alias triplets touching the "bananas" node; edges in both directions are returned.
kgraph.get_triplets(
    relation_names=sorted(R_ALIAS),  # used as a membership filter, so order is irrelevant
    ids=[bananas.id],
)

[(EntityNode(label='VALUE', embedding=None, properties={}, name='number of bananas'),
  Relation(label='SAME_AS', source_id='number of bananas', target_id='bananas', properties={}),
  EntityNode(label='VALUE', embedding=None, properties={}, name='bananas')),
 (EntityNode(label='VALUE', embedding=None, properties={}, name='bananas'),
  Relation(label='SAME_AS', source_id='bananas', target_id='number of bananas', properties={}),
  EntityNode(label='VALUE', embedding=None, properties={}, name='number of bananas'))]

In [5]:
# TODO tune extractor prompt so that focus on business metrics words
kw_extractor = KeywordExtractor(llm=None, keywords=6)

relations: list[Relation] = []

# Lower-cased node name -> node id, restricted to the metric/equation nodes that
# evidence sentences may be linked to.
name_to_id = {
    node.name.lower(): node.id
    for node in kgraph.graph.get_all_nodes()
    if node.label in {"VALUE", "EQUATION"}
}

def get_alias_nodes(node_id: str) -> set[str]:
    """Return ``node_id`` plus every node reachable from it via R_ALIAS relations.

    Alias chains are followed transitively (a -> b -> c), so a sentence linked to
    one name is also linked to every alias of that name, not just direct ones.
    """
    connected = {node_id}
    frontier = [node_id]
    while frontier:
        current = frontier.pop()
        for sub, _, obj in kgraph.get_triplets(ids=[current], relation_names=list(R_ALIAS)):
            for neighbour in (sub.id, obj.id):
                if neighbour not in connected:
                    connected.add(neighbour)
                    frontier.append(neighbour)
    return connected

all_keywords = kw_extractor.extract(sentences)
# TODO can run async
for sentence, keywords in zip(sentences, all_keywords):
    # The extractor yields a comma-separated keyword string; tolerate a missing key.
    keywords_clean = [
        k.strip().lower() for k in keywords.get("excerpt_keywords", "").split(",")
    ]

    # Node ids already linked to this sentence: overlapping keywords (e.g. "apples"
    # and "number of apples") resolve to the same alias group and must not emit
    # duplicate relations or log lines.
    linked_ids: set[str] = set()
    for kw in keywords_clean:
        if kw not in name_to_id:
            continue  # skip keywords that have no VALUE/EQUATION node

        # add relation to all alias nodes; sorted so the log order is deterministic
        for node_id in sorted(get_alias_nodes(name_to_id[kw]) - linked_ids):
            linked_ids.add(node_id)
            print(f"adding relations between {node_id} and {sentence.text}")
            relations.append(
                Relation(
                    source_id=node_id,          # VALUE / EQUATION node id
                    target_id=sentence.id_,     # sentence node id
                    label="HAS_SENTENCE",       # custom relation label
                    properties={"source": "keyword_link"},
                )
            )
        # TODO make sure that when later find sentence through alias, the LLM knows the alias
        # probably by passing information as property

if relations:
    kgraph.upsert_relations(relations)

kgraph.show_jupyter_graph()

100%|██████████| 2/2 [00:00<00:00,  2.49it/s]


adding relations between number of apples and the company Applied Materials (AMAT) has 100 apples
adding relations between apples and the company Applied Materials (AMAT) has 100 apples
adding relations between number of pears and the company Applied Materials (AMAT) has 10 pears
adding relations between pears and the company Applied Materials (AMAT) has 10 pears


GraphWidget(layout=Layout(height='630px', width='100%'))

In [6]:
# Inspect the 1-hop neighbourhood of the "apricots" node. Named rel_map so the
# builtin `map` is not shadowed for the rest of the notebook.
rel_map = kgraph.get_rel_map(
    graph_nodes=[apricots],
    depth=1,
    limit=30,
    ignore_rels=None,
)

# Plain loop: printing from a list comprehension only builds a throwaway list of None.
for triplet in rel_map:
    print(triplet)

query_str = "how many apricots does the company have?"
qstr = query_str.lower()
# Node ids whose (lower-cased) name occurs as a substring of the query.
seeds_id = [node_id for name, node_id in name_to_id.items() if name in qstr]

print(name_to_id.items())


(EntityNode(label='VALUE', embedding=None, properties={}, name='apricots'), Relation(label='ALIAS_OF', source_id='apricots', target_id='number of apricots', properties={}), EntityNode(label='VALUE', embedding=None, properties={}, name='number of apricots'))
(EntityNode(label='EQUATION', embedding=None, properties={}, name='number of apricots = number of bananas'), Relation(label='DEPENDS_ON', source_id='number of apricots = number of bananas', target_id='apricots', properties={}), EntityNode(label='VALUE', embedding=None, properties={}, name='apricots'))
(EntityNode(label='VALUE', embedding=None, properties={}, name='number of apricots'), Relation(label='ALIAS_OF', source_id='number of apricots', target_id='apricots', properties={}), EntityNode(label='VALUE', embedding=None, properties={}, name='apricots'))
dict_items([('apples', 'apples'), ('pears', 'pears'), ('bananas', 'bananas'), ('apricots', 'apricots'), ('number of apples', 'number of apples'), ('number of pears', 'number of pear

In [7]:
# Round-trip check: read the first sentence back from the graph store by its id.
sent_id = sentences[0].id_
stored_nodes = kgraph.get(ids=[sent_id])
sent_raw = stored_nodes[0]

print(sent_raw.text)

the company Applied Materials (AMAT) has 100 apples


In [26]:
from typing import List, Dict, Set, Optional, Union, Any
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import QueryBundle, NodeWithScore, TextNode
# from llama_index.core.graph_stores import PropertyGraphStore
from llama_index.core import VectorStoreIndex
from llama_index.core.query_engine import BaseQueryEngine
# from llama_index.core.llms import LLM, OpenAI # Using OpenAI as an example LLM, replace with your actual LLM
from llama_index.core import PromptTemplate
from llama_index.core.response import Response
from llama_index.core.callbacks import CallbackManager

# Vector index over the evidence sentences. It is handed to CustomGraphRetriever,
# which stores it but does not query it during graph traversal.
sent_index = VectorStoreIndex(nodes=sentences)

class CustomGraphRetriever(BaseRetriever):
    """
    Recursive VALUE→EQUATION→VALUE search over a property graph.

    Starting from the nodes named in the query, the retriever:

    * returns a VALUE node as "terminal" when it has supporting sentences, i.e.
      ``R_HAS_SENTENCE`` edges to nodes that actually carry text;
    * otherwise follows ``R_EQUATION`` edges to EQUATION nodes, emits each
      equation once, and recurses into the VALUE nodes it ``R_DEPENDS`` on;
    * follows ``R_ALIAS`` edges between VALUE nodes without increasing depth.

    Outputs terminal VALUE nodes (wrapped as ``VALUE_SENTENCE`` TextNodes) and all
    encountered EQUATION nodes, scored ``1 / (depth + 1)`` and sorted by score.

    NOTE(review): the traversal is depth-first with shared visited sets, so a node
    first reached via a long path keeps that path's (lower) score and is not
    re-expanded from a shorter path; close to ``max_depth`` this can prune inputs
    that a breadth-first search would still reach.
    """

    # ---- relation labels (adapt to your own schema) -------------
    R_HAS_SENTENCE = {"HAS_SENTENCE"}
    R_EQUATION     = {"HAS_EQUATION", "COMPUTED_BY", "EQUALS"} # e.g., VALUE -[COMPUTED_BY]-> EQUATION
    R_DEPENDS      = {"DEPENDS_ON", "USES", "INPUT"} # e.g., EQUATION -[DEPENDS_ON]-> VALUE
    R_ALIAS        = {"ALIAS_OF", "SAME_AS"}

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        sentence_index: VectorStoreIndex,
        name_to_id: Dict[str, str],
        max_depth: int = 5,
    ):
        """
        Args:
            graph_store: graph holding VALUE/EQUATION nodes, sentence nodes and
                the relations between them.
            sentence_index: vector index over the sentences; stored but not used
                by the current traversal.
            name_to_id: node name -> node id; names are lower-cased here and
                matched against the query as whole phrases.
            max_depth: maximum recursion depth (alias hops do not count).
        """
        super().__init__()
        self._graph = graph_store
        self._sidx = sentence_index  # kept for future seeding; unused by _collect
        # map metric names to node IDs (case-insensitive lookups)
        self._name2id = {k.lower(): v for k, v in name_to_id.items()}
        self._max_depth = max_depth

    # ----------------- helper: outgoing triples ------------------
    def _outgoing(self, src_id: str, rel_set: Optional[Set[str]] = None):
        """Yield ``(subject, relation, object)`` triples whose subject is ``src_id``.

        ``get_triplets(ids=...)`` returns edges touching ``src_id`` in either
        direction, so incoming edges are dropped here. When ``rel_set`` is given,
        only relations whose label is in it are kept; the filter is also passed
        to the store so it can prune before returning.
        """
        relation_names = list(rel_set) if rel_set is not None else None
        for subj, rel, obj in self._graph.get_triplets(ids=[src_id], relation_names=relation_names):
            if subj.id != src_id:
                continue
            if rel_set is not None and rel.label not in rel_set:
                continue
            yield subj, rel, obj

    # ----------------- helper: evidence sentences ----------------
    def _sentence_lines(self, value_id: str) -> List[str]:
        """Return ``"- <text>"`` lines for sentences linked to ``value_id``.

        Linked nodes that cannot be fetched or carry no text are skipped, so an
        empty list means "no usable evidence" even if R_HAS_SENTENCE edges exist.
        """
        lines: List[str] = []
        for _, _, sent_obj in self._outgoing(value_id, self.R_HAS_SENTENCE):
            fetched = self._graph.get(ids=[sent_obj.id])
            text = getattr(fetched[0], "text", None) if fetched else None
            if text:
                lines.append(f"- {text}")
        return lines

    # ----------------- recursive collection ---------------------
    def _collect(
        self,
        node_id: str,
        depth: int,
        seen_vals: Set[str],
        seen_eqs_ids: Set[str],
        terminal_value_nodes_output: List[NodeWithScore],
        equation_nodes_output: List[NodeWithScore],
    ) -> None:
        """
        Recursively collect VALUE nodes with supporting sentences and EQUATION nodes.

        Args:
            node_id: id of the graph node to expand.
            depth: current recursion depth; nothing is done beyond ``max_depth``.
            seen_vals: VALUE ids already expanded (cycle guard); mutated.
            seen_eqs_ids: EQUATION ids already emitted; mutated.
            terminal_value_nodes_output: receives VALUE_SENTENCE nodes; mutated.
            equation_nodes_output: receives EQUATION nodes; mutated.
        """
        if depth > self._max_depth:
            return

        raw_nodes = self._graph.get(ids=[node_id])
        if not raw_nodes:
            return
        raw_node = raw_nodes[0]  # the stored node object (EntityNode or ChunkNode)

        score = 1.0 / (depth + 1)  # shallower nodes rank higher

        if raw_node.label == "VALUE":
            if node_id in seen_vals:
                return
            # Mark before recursing so alias/equation cycles terminate.
            seen_vals.add(node_id)

            # 1. Direct evidence. Only non-empty evidence makes this node terminal;
            #    otherwise fall through to the equation search below.
            lines = self._sentence_lines(node_id)
            if lines:
                text = (
                    f"The following information likely contains the value of {raw_node.name}:\n"
                    + "\n".join(lines)
                )
                wrapped = TextNode(
                    text=text,
                    id_=node_id,
                    metadata={"node_type": "VALUE_SENTENCE", "original_label": raw_node.label},
                )
                terminal_value_nodes_output.append(NodeWithScore(node=wrapped, score=score))
                return

            # 2. No usable evidence: descend into equations involving this VALUE
            #    (VALUE -[R_EQUATION]-> EQUATION).
            for _, _, eq_obj in self._outgoing(node_id, self.R_EQUATION):
                if eq_obj.label == "EQUATION":
                    self._collect(
                        eq_obj.id, depth + 1, seen_vals, seen_eqs_ids,
                        terminal_value_nodes_output, equation_nodes_output,
                    )

            # 3. Alias hop: same logical quantity, so depth is not increased.
            for _, _, alias_obj in self._outgoing(node_id, self.R_ALIAS):
                if alias_obj.label == "VALUE":
                    self._collect(
                        alias_obj.id, depth, seen_vals, seen_eqs_ids,
                        terminal_value_nodes_output, equation_nodes_output,
                    )

        elif raw_node.label == "EQUATION":
            if node_id in seen_eqs_ids:
                return
            # The equation text is stored as the node name.
            textnode = TextNode(
                text=raw_node.name,
                id_=node_id,
                metadata={"node_type": "EQUATION", "original_label": raw_node.label},
            )
            equation_nodes_output.append(NodeWithScore(node=textnode, score=score))
            seen_eqs_ids.add(node_id)

            # Recurse into every input VALUE (EQUATION -[R_DEPENDS]-> VALUE).
            for _, _, dep_obj in self._outgoing(node_id, self.R_DEPENDS):
                if dep_obj.label == "VALUE":
                    self._collect(
                        dep_obj.id, depth + 1, seen_vals, seen_eqs_ids,
                        terminal_value_nodes_output, equation_nodes_output,
                    )
        # Any other node type (COMPANY, sentence chunks, ...) is ignored.

    # ----------------- public API --------------------------------
    def _retrieve(
        self,
        query: QueryBundle,
    ) -> List[NodeWithScore]:
        """Seed from names mentioned in the query, expand the graph, rank by score."""
        import re  # local import keeps this notebook cell self-contained

        q = query.query_str.lower()

        # 1. Seed with nodes whose name appears in the query as a whole phrase
        #    (not embedded in a longer word, e.g. "pears" inside "spears").
        seeds = [
            nid
            for name, nid in self._name2id.items()
            if re.search(rf"(?<!\w){re.escape(name)}(?!\w)", q)
        ]
        if not seeds:
            print("No direct name match found in query. Using all known metrics as seeds.")
            seeds = list(self._name2id.values())
            if not seeds:
                return []

        terminal_value_nodes_output: List[NodeWithScore] = []
        equation_nodes_output: List[NodeWithScore] = []
        # Shared across seeds so overlapping subgraphs are expanded only once.
        seen_vals: Set[str] = set()
        seen_eqs_ids: Set[str] = set()

        # 2. Collect results via recursive tree search.
        for vid in seeds:
            self._collect(vid, 0, seen_vals, seen_eqs_ids, terminal_value_nodes_output, equation_nodes_output)

        # 3. Stable sort by score, highest first; ties keep evidence before equations.
        all_results = terminal_value_nodes_output + equation_nodes_output
        all_results.sort(key=lambda x: x.score, reverse=True)

        return all_results

    def retrieve(
        self,
        query: Union[str, QueryBundle],
    ) -> List[NodeWithScore]:
        """Public retrieve: accept a plain string or a QueryBundle.

        NOTE(review): this calls ``_retrieve`` directly and presumably bypasses
        any callback/instrumentation hooks of ``BaseRetriever.retrieve`` — confirm
        that is intended.
        """
        if not isinstance(query, QueryBundle):
            query = QueryBundle(query_str=str(query))
        return self._retrieve(query)


class CustomGraphQueryEngine(BaseQueryEngine):
    """
    Query engine that answers from CustomGraphRetriever output with an LLM.

    Retrieved nodes are split by ``metadata["node_type"]`` into DATA
    (``VALUE_SENTENCE``) and EQUATIONS (``EQUATION``) and rendered into a single
    prompt that is sent to the LLM. The retrieved nodes are returned as the
    response's ``source_nodes``.
    """
    def __init__(self, retriever: CustomGraphRetriever, llm: OpenAI):
        # Pass a default CallbackManager to the superclass constructor
        super().__init__(callback_manager=CallbackManager())
        self._retriever = retriever
        self._llm = llm

        # Answer prompt; exposed under the key "prompt_template" via get_prompts().
        self._prompt_template = PromptTemplate(
            """\
Answer the following query using the data and equations.
----------------------------------------------------------
QUERY:
{query_str}
----------------------------------------------------------
DATA:
{supporting_data}
----------------------------------------------------------
EQUATIONS:
{equations}
----------------------------------------------------------
Concisely answer to the query up to 2 decimal places.
"""
        )

    # ----------------- prompt management (PromptMixin) -----------
    def _get_prompts(self) -> Dict[str, Any]:
        """Expose the answer template so get_prompts()/update_prompts() can reach it."""
        return {"prompt_template": self._prompt_template}

    def _update_prompts(self, prompts_dict: Dict[str, Any]) -> None:
        """Replace the answer template when update_prompts() supplies one."""
        if "prompt_template" in prompts_dict:
            self._prompt_template = prompts_dict["prompt_template"]

    def _get_prompt_modules(self) -> Dict[str, Any]:
        """No prompt sub-modules: the template is a prompt, not a PromptMixin module."""
        return {}

    # ----------------- query execution ---------------------------
    def _build_prompt(self, query_bundle: QueryBundle, retrieved_nodes: List[NodeWithScore]) -> str:
        """Render the answer prompt from the query and the retrieved nodes."""
        supporting_data_list = []
        equations_list = []

        # Separate retrieved nodes into DATA (VALUE_SENTENCE) and EQUATIONS by metadata.
        for node_with_score in retrieved_nodes:
            node_type = node_with_score.node.metadata.get("node_type")
            if node_type == "VALUE_SENTENCE":
                supporting_data_list.append(node_with_score.node.text)
            elif node_type == "EQUATION":
                equations_list.append(node_with_score.node.text)

        supporting_data_str = "\n".join(supporting_data_list) if supporting_data_list else "No direct supporting data found."
        equations_str = "\n".join(equations_list) if equations_list else "No equations found."

        formatted_prompt = self._prompt_template.format(
            query_str=query_bundle.query_str,
            supporting_data=supporting_data_str,
            equations=equations_str
        )

        print("formatted_prompt:\n", formatted_prompt)
        return formatted_prompt

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Retrieve graph evidence, build the prompt, and ask the LLM synchronously."""
        retrieved_nodes: List[NodeWithScore] = self._retriever.retrieve(query_bundle)
        formatted_prompt = self._build_prompt(query_bundle, retrieved_nodes)
        llm_response = self._llm.complete(formatted_prompt)
        return Response(response=llm_response.text, source_nodes=retrieved_nodes)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Async variant of ``_query``: retrieval stays synchronous, the LLM call is awaited."""
        retrieved_nodes: List[NodeWithScore] = self._retriever.retrieve(query_bundle)
        formatted_prompt = self._build_prompt(query_bundle, retrieved_nodes)
        llm_response = await self._llm.acomplete(formatted_prompt)
        return Response(response=llm_response.text, source_nodes=retrieved_nodes)


In [27]:
retriever = CustomGraphRetriever(kgraph, sent_index, name_to_id, max_depth=10)

nodes = retriever.retrieve(
    "how many apricots does the company have?"
)

# Plain loop: a print comprehension on the last line of a cell also echoes a
# useless list of None values.
for node in nodes:
    print(node)

Node ID: number of apricots = number of bananas
Text: number of apricots = number of bananas
Score:  0.500

Node ID: number of bananas = number of apples - number of pears
Text: number of bananas = number of apples - number of pears
Score:  0.250

Node ID: apples
Text: The following information likely contains the value of apples: -
the company Applied Materials (AMAT) has 100 apples
Score:  0.200

Node ID: pears
Text: The following information likely contains the value of pears: -
the company Applied Materials (AMAT) has 10 pears
Score:  0.200



[None, None, None, None]

In [28]:
# End-to-end run: graph retrieval supplies data + equations, the LLM computes the answer.
query_engine = CustomGraphQueryEngine(
    retriever,
    llm=OpenAI(model="gpt-4o-mini", temperature=0),
)

response = query_engine.query("how many apricots does the company have?")

print(response)

formatted_prompt:
 Answer the following query using the data and equations.
----------------------------------------------------------
QUERY:
how many apricots does the company have?
----------------------------------------------------------
DATA:
The following information likely contains the value of apples:
- the company Applied Materials (AMAT) has 100 apples
The following information likely contains the value of pears:
- the company Applied Materials (AMAT) has 10 pears
----------------------------------------------------------
EQUATIONS:
number of apricots = number of bananas
number of bananas = number of apples - number of pears
----------------------------------------------------------
Concisely answer to the query up to 2 decimal places.

The company has 90 apricots.


In [None]:
from llama_index.core import PromptTemplate

llm = OpenAI(model="gpt-4o-mini", temperature=0)

query_str = "how many apricots does the company have?"

stream = await llm.astream(PromptTemplate(f"""
Answer the following query using the data and equations.
----------------------------------------------------------
QUERY:
{query_str}
----------------------------------------------------------
DATA:
The company has 102 apples.
The company has 30.53 pears.
----------------------------------------------------------
EQUATIONS:
number of apricots = number of bananas
number of bananas = number of apples - number of pears
----------------------------------------------------------
Concisely answer to the query up to 2 decimal places.
""")
)


async for token in stream:
    print(token, end="")

