In [None]:
from typing import Annotated
from typing_extensions import TypedDict
from langchain_community.utilities import ArxivAPIWrapper, WikipediaAPIWrapper
from langchain_community.tools import ArxivQueryRun, WikipediaQueryRun

In [None]:
# arXiv search tool, configured for short answers: one result, content capped at 300 chars.
arxivWrapper = ArxivAPIWrapper(top_k_results=1, doc_content_chars_max=300)
arxivTool = ArxivQueryRun(api_wrapper=arxivWrapper)

arxivTool

In [None]:
# Wikipedia search tool, configured for short answers: one result, content capped at 300 chars.
wikipediaWrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=300)
wikiTool = WikipediaQueryRun(api_wrapper=wikipediaWrapper)

wikiTool

In [None]:
# Smoke test: run a single Wikipedia lookup through the tool interface.
wikiTool.invoke(input="Who is Virat Kohli")

In [None]:
# Smoke test: run a single arXiv lookup by paper title through the tool interface.
arxivTool.invoke(input="Attention is all you need")

In [None]:
# Shared tool list: later bound to the LLM (so it can request these tools)
# and handed to ToolNode (so the graph can execute those requests).
tools = [arxivTool, wikiTool]

## LangGraph Application

In [None]:
from langgraph.graph.message import add_messages

class State(TypedDict):
    """Shared graph state: the running conversation.

    `messages` is annotated with the `add_messages` reducer, so the messages a
    node returns are merged into the existing list rather than replacing it.
    """
    messages: Annotated[list, add_messages]

In [None]:
from langgraph.graph import StateGraph, START, END

# Builder for a graph whose shared state follows the `State` schema defined above.
graphBuilder = StateGraph(state_schema=State)

In [None]:
import os

from langchain_groq import ChatGroq

# No earlier cell defines `groqApiKey`, so this cell previously failed with
# NameError on a fresh kernel (or silently relied on leaked/hardcoded state).
# Read the key from the environment so it never lives in the notebook.
groqApiKey = os.environ.get("GROQ_API_KEY")
if not groqApiKey:
    raise RuntimeError(
        "GROQ_API_KEY is not set; export it in the environment before running this notebook."
    )

# NOTE(review): confirm "Gemma2-9b-It" is still served by Groq; update the model name if not.
llm = ChatGroq(groq_api_key=groqApiKey, model_name="Gemma2-9b-It")
llm

In [None]:
# Expose both tools' schemas to the model so its replies may contain tool calls.
llmWithTools = llm.bind_tools(tools)

In [None]:
def chatbot(state: State):
    """Run one LLM turn over the conversation so far.

    The tool-bound model sees the whole message history and replies either
    with text or with tool calls. The reply is wrapped in a one-element list
    so the `messages` reducer merges it into the existing state.
    """
    conversation = state["messages"]
    aiReply = llmWithTools.invoke(conversation)
    return {"messages": [aiReply]}

In [None]:
from langgraph.prebuilt import ToolNode, tools_condition

# Node 1: the tool-aware LLM, wired as the graph's entry point.
graphBuilder.add_node('chatbot', chatbot)
graphBuilder.add_edge(START, 'chatbot')

# Node 2: executes whatever tool calls the LLM emits, using the same tool list.
toolNode = ToolNode(tools)
graphBuilder.add_node('tools', toolNode)

In [None]:
# After each chatbot turn, `tools_condition` routes to the 'tools' node when the
# latest AI message contains tool calls, and to END otherwise.
graphBuilder.add_conditional_edges(
    'chatbot',
    tools_condition
)

# Tool results go back to the LLM so it can answer or request more tools.
graphBuilder.add_edge('tools', 'chatbot')

# Deliberately no unconditional 'chatbot' -> END edge: `tools_condition` already
# routes to END. A static edge would fire alongside the conditional one and draw
# a misleading "always ends here" edge in the graph diagram.

In [None]:
# Compile the builder into the runnable graph object that is streamed below.
graph = graphBuilder.compile()

In [None]:
from IPython.display import Image, display

# Best-effort diagram: PNG rendering may depend on resources that are not always
# available (e.g. network access for the Mermaid renderer). Report the failure
# instead of silently showing nothing, but never stop the notebook over it.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Graph diagram unavailable: {exc!r}")

## Generating Responses

In [None]:
# First user turn fed into the graph in the next cell.
userInput = "Hello. My name is Keshav Saraogi"

In [None]:
# Seed the conversation with the user's message; 'values' mode yields the full
# state after each step, so the newest message is always the last one.
initialState = {'messages': [('user', userInput)]}
events = graph.stream(initialState, stream_mode='values')

for event in events:
    latestMessage = event['messages'][-1]
    latestMessage.pretty_print()