In [None]:
import os
import functools
from pprint import pprint
from dotenv import load_dotenv

from langchain_openai import AzureChatOpenAI
from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureOpenAIEmbeddings
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.memory import MemorySaver

from utils.prompts import (
    default_router_prompt,
    default_chat_prompt,
    default_tool_using_prompt,
)
from utils.router import AgentState, classifier_router
from utils.agent import create_agent, agent_node, invoke_graph
from utils.tools import search_web_tool, create_vector_store_tool

# Load environment variables (API keys, endpoints, deployment names read via
# os.getenv in the cells below) from a local .env file. By default python-dotenv
# does not override variables that are already set in the process environment.
load_dotenv()

# Set up Clients

In [2]:
# Chat model shared by every agent in the graph; configuration comes from the environment.
# NOTE(review): the endpoint variable was originally spelled "AZURE_OPENAI_GPOT4O_ENDPOINT"
# (likely a typo for GPT4O). Prefer the correctly spelled name and fall back to the legacy
# one so existing .env files keep working.
_gpt4o_endpoint = os.getenv("AZURE_OPENAI_GPT4O_ENDPOINT") or os.getenv(
    "AZURE_OPENAI_GPOT4O_ENDPOINT"
)
# os.getenv returns a string (or None); convert explicitly instead of relying on the
# client's implicit string coercion. Unset/empty -> None, as before when unset.
_temperature_raw = os.getenv("AZURE_OPENAI_TEMPERATURE")
gpt4o_client = AzureChatOpenAI(
    model=os.getenv("AZURE_OPENAI_MODEL"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version=os.getenv("AZURE_OPENAI_VERSION"),
    azure_endpoint=_gpt4o_endpoint,
    temperature=float(_temperature_raw) if _temperature_raw else None,
)

In [3]:
# Embedding model whose `embed_query` is used by the Azure AI Search vector store below.
# os.getenv returns a string (or None); convert the dimension count explicitly rather than
# relying on implicit string coercion. int() raises ValueError on a non-numeric value.
# Unset/empty -> None, which leaves the dimension choice to the model.
# NOTE(review): the dimension must match the index's vector field size - confirm.
_embedder_dimensions = os.getenv("AZURE_OPENAI_EMBEDDER_DIMENSIONS")
embeddings: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings(
    azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDER_DEPLOYMENT"),
    openai_api_version=os.getenv("AZURE_OPENAI_VERSION"),
    azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDER_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    dimensions=int(_embedder_dimensions) if _embedder_dimensions else None,
)

In [4]:
# Azure AI Search index wrapped as a LangChain vector store. Connection details come
# from the environment; text is embedded with `embeddings.embed_query`.
_search_connection = {
    "azure_search_endpoint": os.getenv("AI_SEARCH_ENDPOINT"),
    "azure_search_key": os.getenv("AI_SEARCH_API_KEY"),
    "index_name": os.getenv("INDEX_NAME"),
}
vector_store: AzureSearch = AzureSearch(
    embedding_function=embeddings.embed_query,
    **_search_connection,
)

# Create Agents

In [5]:
# Expose the vector store as a tool; it is later bound to the tool-using agent
# and registered with the ToolNode.
retriever_tool = create_vector_store_tool(
    vector_store,
)

In [6]:
# All three agents share the same chat model and differ only in prompt and tools.
# The classifier uses the router prompt. The chat agent answers without tools. The
# tool-using agent is given the web-search and retrieval tools.
classifier_agent, chat_agent = (
    create_agent(gpt4o_client, prompt)
    for prompt in (default_router_prompt, default_chat_prompt)
)
tool_using_agent = create_agent(
    gpt4o_client,
    default_tool_using_prompt,
    tools=[search_web_tool, retriever_tool],
)

## Create Nodes

In [7]:
# Turn each agent into a graph-node callable by pre-binding the `agent` and `name`
# keyword arguments of `agent_node`. Each name matches the node id used when the
# nodes are added to the workflow.
classifier_node, chat_node, tool_using_node = (
    functools.partial(agent_node, agent=bound_agent, name=node_name)
    for bound_agent, node_name in (
        (classifier_agent, "classifier"),
        (chat_agent, "chat"),
        (tool_using_agent, "tool_using"),
    )
)

In [8]:
# LangGraph's ToolNode runs the tool calls found in the latest message.
# Keep this list identical to the tools given to `tool_using_agent`; otherwise
# a requested tool may be missing here.
tools = [search_web_tool, retriever_tool]
tool_node = ToolNode(tools=tools)

# Create Graph

In [9]:
# Define the graph workflow over the shared AgentState.
# Topology: classifier -> (SPECIFIC: tool_using -> tool -> chat | GENERAL: chat) -> END
workflow = StateGraph(AgentState)

# Register nodes; these string ids are what the edges and the router's path map refer to.
workflow.add_node("classifier", classifier_node)
workflow.add_node("chat", chat_node)
workflow.add_node("tool_using", tool_using_node)
workflow.add_node("tool", tool_node)

# classifier_router's return value is looked up in this path map to choose the next node.
# NOTE(review): assumes classifier_router only returns "SPECIFIC" or "GENERAL" - confirm in utils.router.
workflow.add_conditional_edges(
    "classifier",
    classifier_router,
    {"SPECIFIC": "tool_using", "GENERAL": "chat"},
)

# After the tools run, control passes to "chat", the only node that leads to END.
# NOTE(review): "tool" always runs after "tool_using", even if that agent emitted no tool
# calls; consider routing with langgraph.prebuilt.tools_condition if that case matters.
workflow.add_edge("tool_using", "tool")
workflow.add_edge("tool", "chat")
workflow.add_edge("chat", END)

# Every run starts at the classifier.
workflow.set_entry_point("classifier")

# The graph is compiled once, in the next cell, together with its checkpointer.
# A separate compile here only produced an object that was immediately overwritten.

In [10]:
# In-process checkpoint store, so graph state is kept between invocations that share
# a thread id. It is held in memory only and is lost when the kernel restarts.
# NOTE(review): a checkpointed graph needs a thread_id in the invoke config -
# presumably supplied by invoke_graph; confirm.
checkpointer = MemorySaver()

# debug=True enables LangGraph's verbose per-step debug output.
graph = workflow.compile(
    checkpointer=checkpointer,
    debug=True,
)

# Test the Agent Graph

In [None]:
# General-knowledge question. NOTE(review): presumably routed GENERAL (chat only) - check the debug trace.
wwii_question = "Which country started WWII?"
result = invoke_graph(graph=graph, messages=wwii_question)

In [None]:
# Pretty-print whatever invoke_graph returned (width=80 is pprint's default).
pprint(result, width=80)

In [None]:
# Company-specific question. NOTE(review): presumably routed SPECIFIC (tools, then chat) - check the debug trace.
tesla_question = "What are the main risks of Tesla?"
result = invoke_graph(graph=graph, messages=tesla_question)

In [None]:
# Pretty-print whatever invoke_graph returned (width=80 is pprint's default).
pprint(result, width=80)