In [1]:
from langchain.chat_models import init_chat_model
from typing import Union
from fastapi import FastAPI
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain.chat_models import init_chat_model
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import ToolMessage
from langchain_tavily import TavilySearch
from pydantic import BaseModel
from IPython.display import Image, display
import getpass
import os
import dotenv
import sys


root_path = "~"
env_path = f"{os.path.expanduser(root_path)}/fold/.env"
dotenv.load_dotenv(env_path)


app = FastAPI()
memory = MemorySaver()


class State(TypedDict):
    messages: Annotated[list, add_messages]


class Chat(BaseModel):
    text: str
    search: Union[bool, None] = None


class BasicToolNode:
    def __init__(self, tools: list) -> None:
        self.tools = {tool.name: tool for tool in tools}
        print("here", self.tools)

    def get_tools(self):
        return self.tools

    def __call__(self, input):
        if messages := input.get("messages", []):
            last_msg = messages[-1]
        else:
            raise ValueError("message not provided")

        outputs = []
        for tool_call in last_msg.tool_calls:
            tool_result = self.tools[tool_call["name"]].invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=tool_result,
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )

            return {"messages": outputs}


def route_node(state):
    if isinstance(state, list) and len(state) > 0:
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = state["messages"][-1]
    else:
        raise ValueError("""state -> messages is not provided or null""")

    if hasattr(ai_message, "tool_calls") and len(ai_message["tool_calls"]) > 0:
        return "tools"
    return END


graph = StateGraph(State)
tool = TavilySearch(max_results=2)
tools = [tool]
tool_node = BasicToolNode(tools)


def main():

    val = dotenv.get_key(
        dotenv_path=env_path,
        key_to_get="GROQ_API_KEY",
    )
    if not os.environ.get("GROQ_API_KEY") and val:
        os.environ["GROQ_API_KEY"] = getpass.getpass(
            "Provide your Groq model Password :"
        )

    else:
        print("loaded env variable")
    llm = init_chat_model(model="llama3-8b-8192", model_provider="groq", api_key=val)
    llm_with_tools = llm.bind_tools(tools)

    def chatbot(state: State):
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    graph.add_node("chatbot", chatbot)
    graph.add_node("tools", tool_node)

    graph.add_edge(START, "chatbot")
    graph.add_conditional_edges("chatbot", route_node, {"tools": "tools", END: END})
    graph.add_edge("tools", "chatbot")
    App = graph.compile()

    try:
        display(Image(App.get_graph().draw_mermaid_png()))
    except Exception:
        # This requires some extra dependencies and is optional
        pass

    @app.get("/")
    def test():
        msg = [HumanMessage("Hello world")]
        res = llm.invoke(msg)
        return res.content

    @app.post("/talk")
    def talk(chat: Chat):
        if chat.search:

            res = tool.invoke("What is Javascript")
            print(res)
            return res

        msg = [HumanMessage(chat.text)]
        for val in App.stream({"messages": msg}):
            for vals in val.values():
                print("Assistant: ", vals["messages"][-1].content)
                return vals["messages"][-1].content


if __name__ == "main":
    main()


here {'tavily_search': TavilySearch(max_results=2, api_wrapper=TavilySearchAPIWrapper(tavily_api_key=SecretStr('**********')))}
