Key references:
- [How to invoke the tool with ToolCall to return `Artifacts`](https://python.langchain.com/docs/how_to/tool_artifacts/#invoking-the-tool-with-toolcall)
    - https://js.langchain.com/docs/how_to/tool_artifacts/#invoking-the-tool-with-toolcall
- [Hiding arguments from the model](https://python.langchain.com/docs/how_to/tool_runtime/#other-ways-of-annotating-args)
- [DeepMind's DeepResearch implementation with Langgraph](https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart#)
- [Langchain's deepresearch example](https://github.com/langchain-ai/open_deep_research#)
    - define a `from_runnable_config` classmethod to build a typed configuration object from a `RunnableConfig`

In [3]:
from typing import Annotated, Any, List, Literal, Optional, Tuple

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool, InjectedToolArg
from langchain_core.tools.base import ArgsSchema
from pydantic import BaseModel, Field

In [4]:
# Argument schema for the retriever tool.
# (Kept as `#` comments on purpose: a class docstring on a pydantic model
# would show up as the `description` of the generated JSON schema.)
class RetrieverInput(BaseModel):
    # Free-text query to search for.
    query: str = Field(description="User query for information retrieval")

    # Marked with InjectedToolArg; defaults to 5 when not supplied.
    retrieval_top_k: Annotated[int, InjectedToolArg] = Field(
        default=5,
        description="Number of most relevant records to retrieve from the vector store",
    )


In [13]:
from langchain_core.runnables import RunnableConfig


class RetrieverTool(BaseTool):
    """A stateful retriever tool that holds the vector retrieval service.

    The tool declares ``response_format="content_and_artifact"``, so both
    ``_run`` and ``_arun`` return a ``(content, artifact)`` two-tuple.
    The retrieval itself is currently a stub: the upper-cased query is
    repeated ``retrieval_top_k`` times and a placeholder metadata dict is
    returned as the artifact.
    """

    name: str = "Information Retrieval Tool"
    description: str = "Retrieves most relevant records from the vector store for the provided query."
    args_schema: Optional[ArgsSchema] = RetrieverInput
    # return_direct: bool = True
    response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"

    # Stand-in for a real client/service object held by the tool instance.
    client: str = Field("hello", description="This can be initialized client/service")

    def _run(self, query: str, retrieval_top_k: int, config: RunnableConfig) -> Tuple[List[str], Any]:
        """Retrieves most relevant records from the vector store for the provided query.

        Args:
            query: The user query.
            retrieval_top_k: Number of records to return.
            config: Runnable config passed at invocation; not read by the
                stub implementation.

        Returns:
            A ``(content, artifact)`` tuple: the list of result strings and
            the artifact object that accompanies them.
        """
        # Sketch of the intended real implementation (attribute names are
        # placeholders and do not exist on this class yet):
        # artifact = self.vector_retrieval_service.vector_store.similarity_search(
        #     query=query,
        #     k=self.top_k,
        # )
        # response = [res.data for res in artifact]
        # return response, artifact
        #
        # Alternative: read the value from the config instead of the args:
        # top_k = config["configurable"]["top_k"]
        response = [query.upper()] * retrieval_top_k
        artifact = {"metadata": "example_metadata"}
        return response, artifact

    async def _arun(
        self,
        query: str,
        config: RunnableConfig,
        retrieval_top_k: int = 5,
    ) -> Tuple[List[str], Any]:
        """Asynchronous version of the run method.

        Accepts the same arguments as ``_run`` and returns the same
        ``(content, artifact)`` tuple. ``retrieval_top_k`` is a trailing
        parameter with a default (matching the default declared in
        ``RetrieverInput``) so the previous ``(query, config)`` call shape
        keeps working.
        """
        # The stub does no I/O, so delegating to the sync path is enough.
        # NOTE(review): once `_run` performs blocking retrieval, replace this
        # with a real async client call or offload to an executor.
        return self._run(query=query, retrieval_top_k=retrieval_top_k, config=config)

# Runtime config handed to `invoke` below. Note that `RetrieverTool._run`
# receives it but does not read `top_k` from it.
config = {"configurable": {"top_k": 3}}

In [14]:
# Sanity check: a value passed to the constructor replaces the field default ("hello").
RetrieverTool(client="yooo").client

'yooo'

In [18]:
# Instantiate the tool, then call it with a plain dict of arguments.
# With dict input only the content list comes back (see the ToolCall
# example for getting the artifact as well).
retriever_tool = RetrieverTool(client="example_client", response_format="content_and_artifact")

retriever_tool.invoke(
    {"query": "What is LangChain?", "retrieval_top_k": 3},
    config=config,
)

['WHAT IS LANGCHAIN?', 'WHAT IS LANGCHAIN?', 'WHAT IS LANGCHAIN?']

In [None]:
from langchain_core.messages import ToolCall

# NOTE: the artifact is only returned when the input is a ToolCall (a TypedDict);
# the result is then a ToolMessage carrying both `content` and `artifact`.

retriever_tool.invoke(
    ToolCall(
        type="tool_call",
        id="123",
        name="Information Retrieval Tool",
        args={"query": "What is LangChain?", "retrieval_top_k": 3},
    ),
    config=config,
)

ToolMessage(content=['WHAT IS LANGCHAIN?', 'WHAT IS LANGCHAIN?', 'WHAT IS LANGCHAIN?'], name='Information Retrieval Tool', tool_call_id='123', artifact={'metadata': 'example_metadata'})

In [None]:
retriever_tool = RetrieverTool(client="example_client")
# JSON schema of the tool's input model.
# NOTE(review): expected to list every argument, including the injected
# `retrieval_top_k` — no output was captured for this cell, so confirm.
retriever_tool.get_input_schema().schema()

In [None]:
# `tool_call_schema` is a property on BaseTool, not a method: calling it would
# try to instantiate the schema model with no arguments and fail validation.
# Access it as an attribute to get the schema that is exposed to the model.
retriever_tool.tool_call_schema.schema()

In [24]:
from langgraph.prebuilt import ToolNode

# Wrap the tool in a LangGraph ToolNode. `RetrieverTool()` is built with its
# field defaults here (e.g. client="hello").
retriever_node = ToolNode([RetrieverTool()])

In [30]:
from langchain_core.messages import HumanMessage, AIMessage

# Build an AIMessage that carries a tool call, rather than passing a plain string.
# Only `query` is supplied in the args; `retrieval_top_k` is left out.
message_input = AIMessage(
    content="What is LangChain?",
    tool_calls=[
        dict(
            name="Information Retrieval Tool",
            args=dict(query="What is LangChain?"),
            id="tool_call_1",
        ),
    ],
)

# Run the ToolNode on a messages-style state, the way a graph would call it.
node_output = retriever_node.invoke({"messages": [message_input]}, config=config)

In [31]:
# Inspect the node result: a dict whose "messages" list holds the ToolMessage.
node_output

{'messages': [ToolMessage(content='["WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?"]', name='Information Retrieval Tool', tool_call_id='tool_call_1', artifact={'metadata': 'example_metadata'})]}

In [32]:
# The single ToolMessage produced for the one tool call.
node_output["messages"][0]

ToolMessage(content='["WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?"]', name='Information Retrieval Tool', tool_call_id='tool_call_1', artifact={'metadata': 'example_metadata'})

In [33]:
# Via ToolNode the content is a JSON-encoded string, whereas invoking the
# tool directly with a ToolCall gave a Python list.
node_output["messages"][0].content

'["WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?", "WHAT IS LANGCHAIN?"]'

In [34]:
# The artifact travels on the ToolMessage separately from `content`.
node_output["messages"][0].artifact

{'metadata': 'example_metadata'}