# In [None]:
from typing import Annotated, List, Literal, TypedDict

from dotenv import load_dotenv

# LangChain / LangGraph imports
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode


# -------------------------------------------------------------------
# 1. Environment
# -------------------------------------------------------------------
# Loads OPENAI_API_KEY (and others) from your .env file
load_dotenv()


# -------------------------------------------------------------------
# 2. Documents + embeddings + Chroma DB
# -------------------------------------------------------------------
embeddings = OpenAIEmbeddings()

# Small in-code knowledge base; each document is tagged with the (virtual)
# file it came from in metadata["source"].
docs = [
    Document(
        page_content=(
            "Peak Performance Gym was founded in 2015 by former Olympic athlete Marcus Chen. "
            "With over 15 years of experience in personal training and fitness, Marcus wanted "
            "to create a gym focused on personalized training and holistic wellness."
        ),
        metadata={"source": "about.txt"},
    ),
    Document(
        page_content=(
            "Peak Performance Gym is open Monday through Friday from 5:00 AM to 11:00 PM. "
            "On weekends, our hours are 6:00 AM to 9:00 PM."
        ),
        metadata={"source": "hours.txt"},
    ),
    Document(
        page_content=(
            "Membership plans include: Basic ($30/month) with access to gym floor and basic "
            "equipment, Standard ($50/month) with access to all equipment and 2 classes/week, "
            "Premium ($80/month) with unlimited classes and full access."
        ),
        metadata={"source": "membership_plans.txt"},
    ),
    Document(
        page_content=(
            "Classes at Peak Performance Gym include Yoga (beginner, intermediate, advanced), "
            "HIIT (high intensity interval training), Spin, Zumba, and Strength Training. "
            "All classes are taught by certified professionals with years of experience."
        ),
        metadata={"source": "classes.txt"},
    ),
    Document(
        page_content=(
            "Facilities include a 10,000 sq ft gym floor with free weights, strength training "
            "machines, cardio zone with 30+ machines, sauna, steam room, juice bar, and a "
            "lounge area for members to relax and socialize."
        ),
        metadata={"source": "facilities.txt"},
    ),
]

# NOTE: Make sure you've run `pip install chromadb` in your environment.
# Give every document a stable ID (its unique source name). Without explicit
# IDs a random one is generated per insert, so re-running this cell in the
# same kernel can add duplicate copies to the in-memory collection. With
# fixed IDs, re-runs overwrite the same entries instead.
db = Chroma.from_documents(
    docs,
    embeddings,
    ids=[doc.metadata["source"] for doc in docs],
)
print("✅ Vector store created")


# -------------------------------------------------------------------
# 3. State definition
# -------------------------------------------------------------------
class AgentState(TypedDict):
    """Graph state shared by every node.

    ``messages`` carries the ``add_messages`` reducer. When a node returns
    ``{"messages": [new_msg]}``, ``new_msg`` is appended to the existing
    conversation rather than replacing it.

    Without a reducer, LangGraph overwrites the key with each node's return
    value. The agent would then be re-invoked with only the tool result,
    losing the user prompt and the assistant's tool-call message.
    """

    messages: Annotated[List[BaseMessage], add_messages]


# -------------------------------------------------------------------
# 4. Example tool(s)
# -------------------------------------------------------------------
# NOTE: ``@tool`` derives the tool schema the model sees from the function
# name, its parameters and its docstring. Those are part of the runtime
# contract, not just documentation, so change them deliberately.
@tool
def add(x: int, y: int) -> int:
    """Add two integers."""
    total = x + y
    return total


# Every tool the agent is allowed to call. It is bound to the chat model in
# the agent node and executed by the ToolNode in the graph.
tools = [add]


# -------------------------------------------------------------------
# 5. Agent node
# -------------------------------------------------------------------
def agent(state: AgentState) -> AgentState:
    """Run one chat-model turn over the conversation stored in ``state``.

    The model is created with the module-level ``tools`` bound to it, so its
    reply may contain tool calls for the router to act on.

    Returns a partial state update whose ``messages`` list holds only the
    new assistant message.
    """
    # Choose an actual model name that exists for your account
    # ("gpt-4.1-mini" and similar are alternatives).
    llm = ChatOpenAI(model="gpt-4o-mini").bind_tools(tools)

    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}


# -------------------------------------------------------------------
# 6. Router: should we call tools or end?
# -------------------------------------------------------------------
def should_continue(state: AgentState) -> Literal["tools", END]:
    """Decide where the graph goes after an agent turn.

    Returns ``"tools"`` when the newest message has a non-empty
    ``tool_calls`` attribute. Otherwise returns ``END``, which includes
    messages that have no ``tool_calls`` attribute at all.
    """
    latest = state["messages"][-1]
    return "tools" if getattr(latest, "tool_calls", None) else END


# -------------------------------------------------------------------
# 7. Build workflow graph
# -------------------------------------------------------------------
# Topology:  START -> agent -> (tools -> agent)* -> END
workflow = StateGraph(AgentState)

# Nodes: the LLM turn, plus an executor for any tool calls it emits.
tool_node = ToolNode(tools)
for node_name, node_fn in (("agent", agent), ("tools", tool_node)):
    workflow.add_node(node_name, node_fn)

# Every run begins with the agent (equivalent to an edge from START).
workflow.set_entry_point("agent")
# After each agent turn, should_continue picks "tools" or END.
workflow.add_conditional_edges("agent", should_continue)
# Tool results always go back to the agent for another turn.
workflow.add_edge("tools", "agent")

graph = workflow.compile()
print("✅ Graph compiled successfully")
