173 changes: 173 additions & 0 deletions src/examples/langgraph_example/main.py
@@ -0,0 +1,173 @@
from typing import TypedDict, Union, Annotated
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.tools import tool
import operator
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

from langchain import hub
from langchain.agents import create_openai_tools_agent
import json
from langgraph.graph import StateGraph, END
from langtrace_python_sdk import langtrace, with_langtrace_root_span

load_dotenv()

langtrace.init(write_spans_to_console=False)


class AgentState(TypedDict):
input: str
agent_out: Union[AgentAction, AgentFinish, None]
intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
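    # The operator.add annotation is a LangGraph reducer: updates that nodes
    # return for intermediate_steps are concatenated onto the existing list
    # rather than replacing it.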


ehi_information = """Title: EHI: End-to-end Learning of Hierarchical Index for
Efficient Dense Retrieval
Summary: Dense embedding-based retrieval is now the industry
standard for semantic search and ranking problems, like obtaining relevant web
documents for a given query. Such techniques use a two-stage process: (a)
contrastive learning to train a dual encoder to embed both the query and
documents and (b) approximate nearest neighbor search (ANNS) for finding similar
documents for a given query. These two stages are disjoint; the learned
embeddings might be ill-suited for the ANNS method and vice-versa, leading to
suboptimal performance. In this work, we propose End-to-end Hierarchical
Indexing -- EHI -- that jointly learns both the embeddings and the ANNS
structure to optimize retrieval performance. EHI uses a standard dual encoder
model for embedding queries and documents while learning an inverted file index
(IVF) style tree structure for efficient ANNS. To ensure stable and efficient
learning of discrete tree-based ANNS structure, EHI introduces the notion of
dense path embedding that captures the position of a query/document in the tree.
We demonstrate the effectiveness of EHI on several benchmarks, including
de-facto industry standard MS MARCO (Dev set and TREC DL19) datasets. For
example, with the same compute budget, EHI outperforms state-of-the-art (SOTA)
by 0.6% (MRR@10) on MS MARCO dev set and by 4.2% (nDCG@10) on TREC DL19
benchmarks.
Author(s): Ramnath Kumar, Anshul Mittal, Nilesh Gupta, Aditya Kusupati,
Inderjit Dhillon, Prateek Jain
Source: https://arxiv.org/pdf/2310.08891.pdf"""


@tool("search")
def search_tool(query: str):
"""Searches for information on the topic of artificial intelligence (AI).
Cannot be used to research any other topics. Search query must be provided
in natural language and be verbose."""
# this is a "RAG" emulator
return ehi_information


@tool("final_answer")
def final_answer_tool(answer: str, source: str):
"""Returns a natural language response to the user in `answer`, and a
`source` which provides citations for where this information came from.
"""
return ""


llm = ChatOpenAI()
prompt = hub.pull("hwchase17/openai-functions-agent")


query_agent_runnable = create_openai_tools_agent(
llm=llm, tools=[final_answer_tool, search_tool], prompt=prompt
)


# One-off invocation at import time to exercise the agent runnable; its
# result is unused, since the graph below re-runs the agent through the
# run_query_agent node.
inputs = {"input": "what are EHI embeddings?", "intermediate_steps": []}

agent_out = query_agent_runnable.invoke(inputs)


def run_query_agent(state: AgentState):
print("> run_query_agent")
agent_out = query_agent_runnable.invoke(state)
return {"agent_out": agent_out}


def execute_search(state: AgentState):
print("> execute_search")
action = state["agent_out"]
tool_call = action[-1].message_log[-1].additional_kwargs["tool_calls"][-1]
out = search_tool.invoke(json.loads(tool_call["function"]["arguments"]))
return {"intermediate_steps": [{"search": str(out)}]}


def router(state: AgentState):
print("> router")
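 router")">
    # Tool calls arrive as a list of AgentAction, so route to the tool the
    # agent chose; an AgentFinish (or anything else) falls through to the
    # error handler.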
if isinstance(state["agent_out"], list):
return state["agent_out"][-1].tool
else:
return "error"


# finally, we will have a single LLM call that MUST use the final_answer structure
final_answer_llm = llm.bind_tools([final_answer_tool], tool_choice="final_answer")


# this forced final_answer LLM call will be used to structure output from our
# RAG endpoint
def rag_final_answer(state: AgentState):
print("> final_answer")
query = state["input"]
context = state["intermediate_steps"][-1]

prompt = f"""You are a helpful assistant, answer the user's question using the
context provided.

CONTEXT: {context}

QUESTION: {query}
"""
out = final_answer_llm.invoke(prompt)
function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
return {"agent_out": function_call}


# we use the same forced final_answer LLM call to handle incorrectly formatted
# output from our query_agent
def handle_error(state: AgentState):
print("> handle_error")
query = state["input"]
prompt = f"""You are a helpful assistant, answer the user's question.

QUESTION: {query}
"""
out = final_answer_llm.invoke(prompt)
function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
return {"agent_out": function_call}


@with_langtrace_root_span("run_graph")
def run_graph():
graph = StateGraph(AgentState)

# we have four nodes that will consume our agent state and modify
# our agent state based on some internal process
graph.add_node("query_agent", run_query_agent)
graph.add_node("search", execute_search)
graph.add_node("error", handle_error)
graph.add_node("rag_final_answer", rag_final_answer)
# our graph will always begin with the query agent
graph.set_entry_point("query_agent")
# conditional edges are controlled by our router
graph.add_conditional_edges(
"query_agent",
router,
{
"search": "search",
"error": "error",
"final_answer": END,
},
)
graph.add_edge("search", "rag_final_answer")
graph.add_edge("error", END)
graph.add_edge("rag_final_answer", END)

runnable = graph.compile()

return runnable.invoke({"input": "what are EHI embeddings?"})


if __name__ == "__main__":
run_graph()
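For reference, a minimal sketch of consuming the returned state, assuming OPENAI_API_KEY is available to the load_dotenv() call above. When the graph exits through rag_final_answer or handle_error, agent_out holds the final_answer tool's arguments as a JSON string:

# Hypothetical consumer of run_graph(); relies on the json import above.
result = run_graph()
final = json.loads(result["agent_out"])  # {"answer": "...", "source": "..."}
print(final["answer"])
print(final["source"])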
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/instrumentation/crewai/patch.py
Original file line number Diff line number Diff line change
@@ -223,7 +223,7 @@ def _parse_tasks(self, tasks):
for task in tasks:
self.crew["tasks"].append(
{
"agent": task.agent.role,
"agent": task.agent.role if task.agent else None,
"description": task.description,
"async_execution": task.async_execution,
"expected_output": task.expected_output,
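The one-line guard above avoids an AttributeError when a task has no agent assigned. A minimal illustration with a hypothetical stub (not the real CrewAI Task class):

class Task:  # hypothetical stub mirroring the attribute accessed above
    def __init__(self, agent=None):
        self.agent = agent

task = Task(agent=None)
# Before this change, task.agent.role raised AttributeError for agentless
# tasks; with the guard, the parsed value simply falls back to None.
agent_role = task.agent.role if task.agent else None
assert agent_role is None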
@@ -41,34 +41,29 @@ def _instrument(self, **kwargs):
# List of modules to patch, with their corresponding patch names
modules_to_patch = [
(
"langgraph.graph.graph",
"langgraph.graph.state", # Updated module path
"StateGraph", # Updated class name
[
"add_node",
"add_edge",
"set_entry_point",
"set_finish_point",
"add_conditional_edges",
],
),
)
]

for module_name, methods in modules_to_patch:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(
module,
lambda member: inspect.isclass(member)
and member.__module__ == module.__name__,
):
for method_name, _ in inspect.getmembers(
obj, predicate=inspect.isfunction
):
if method_name in methods:
module = f"{name}.{method_name}"
wrap_function_wrapper(
module_name,
module,
patch_graph_methods(module, tracer, version),
)
for module_name, class_name, methods in modules_to_patch:
for method_name in methods:
# Construct the correct path for the method
method_path = f"{class_name}.{method_name}"
wrap_function_wrapper(
module_name,
method_path,
patch_graph_methods(
f"{module_name}.{method_path}", tracer, version
),
)

def _uninstrument(self, **kwargs):
pass
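For context, a minimal sketch of the wrapt pattern the rewritten loop relies on; make_wrapper is a hypothetical stand-in for patch_graph_methods, and the call assumes langgraph is installed:

from wrapt import wrap_function_wrapper

def make_wrapper(operation_name):
    # wrapt hands the wrapper the original callable, the bound instance,
    # and the call arguments; a real patch would open a span here.
    def wrapper(wrapped, instance, args, kwargs):
        print(f"tracing {operation_name}")
        return wrapped(*args, **kwargs)
    return wrapper

# Mirrors the module-path / "Class.method" convention built in the loop:
wrap_function_wrapper(
    "langgraph.graph.state",
    "StateGraph.add_node",
    make_wrapper("langgraph.graph.state.StateGraph.add_node"),
)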
5 changes: 2 additions & 3 deletions src/langtrace_python_sdk/instrumentation/langgraph/patch.py
@@ -30,6 +30,7 @@
from importlib_metadata import version as v

from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
from langtrace_python_sdk.utils.llm import set_span_attributes


def patch_graph_methods(method_name, tracer, version):
@@ -57,9 +58,7 @@ def traced_method(wrapped, instance, args, kwargs):
kind=SpanKind.CLIENT,
context=set_span_in_context(trace.get_current_span()),
) as span:
for field, value in attributes.model_dump(by_alias=True).items():
if value is not None:
span.set_attribute(field, value)
set_span_attributes(span, attributes)
try:
# Attempt to call the original method
result = wrapped(*args, **kwargs)
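For reference, a plausible sketch of set_span_attributes, assuming it simply centralizes the inline loop this change removes (the real helper in langtrace_python_sdk.utils.llm may differ):

def set_span_attributes(span, attributes):
    # Same behavior as the removed loop: serialize the pydantic model of
    # span attributes and skip fields that were never set.
    for field, value in attributes.model_dump(by_alias=True).items():
        if value is not None:
            span.set_attribute(field, value)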
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
@@ -1 +1 @@
__version__ = "3.3.11"
__version__ = "3.3.12"