In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (presumably the OpenAI credentials the cells below rely on — confirm).
load_dotenv()

In [None]:
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma

# Embedding model handed to the vector store below.
embedding_function = OpenAIEmbeddings()

# (source file, passage text) pairs that make up the demo knowledge base.
_passages = [
    (
        "owner.txt",
        "Bella Vista is owned by Antonio Rossi, a renowned chef with over 20 years of experience in the culinary industry. He started Bella Vista to bring authentic Italian flavors to the community.",
    ),
    (
        "dishes.txt",
        "Bella Vista offers a range of dishes with prices that cater to various budgets. Appetizers start at $8, main courses range from $15 to $35, and desserts are priced between $6 and $12.",
    ),
    (
        "restaurant_info.txt",
        "Bella Vista is open from Monday to Sunday. Weekday hours are 11:00 AM to 10:00 PM, while weekend hours are extended from 11:00 AM to 11:00 PM.",
    ),
    (
        "restaurant_info.txt",
        "Bella Vista offers a variety of menus including a lunch menu, dinner menu, and a special weekend brunch menu. The lunch menu features light Italian fare, the dinner menu offers a more extensive selection of traditional and contemporary dishes, and the brunch menu includes both classic breakfast items and Italian specialties.",
    ),
]

# One Document per passage, tagged with the file it came from.
docs = [
    Document(page_content=text, metadata={"source": source})
    for source, text in _passages
]

# Build the Chroma vector store from the documents.
db = Chroma.from_documents(docs, embedding_function)

# Limit every query to the 2 best-matching passages.
retriever = db.as_retriever(search_kwargs={"k": 2})

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Final-answer prompt. It has three placeholders ({history}, {context} and
# {question}) which generate_answer supplies when it invokes rag_chain.
template = """Answer the question based on the following context and the Chat history. Especially take the latest question into consideration:

Chathistory: {history}

Context: {context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

In [None]:
# Chat model used to write the final answer.
llm = ChatOpenAI(model="gpt-4o-mini")
# Answer chain: fill the prompt defined above, then send it to the model.
rag_chain = prompt | llm

In [None]:
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema import Document
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END


""" Define the graph state (the data passed from node to node), the nodes, and the routers. """

class AgentState(TypedDict):
    # Shared state handed from node to node in the graph.

    # Conversation so far: question_rewriter appends the user question, and
    # generate_answer / cannot_answer / off_topic_response append the AI reply.
    messages: List[BaseMessage]
    # Documents returned by retrieve, later narrowed by retrieval_grader to
    # the ones graded relevant.
    documents: List[Document]
    # Classifier verdict written by question_classifier ('Yes' / 'No').
    on_topic: str
    # Standalone (and possibly refined) question used for retrieval and answering.
    rephrased_question: str
    # True when retrieval_grader kept at least one relevant document.
    proceed_to_generate: bool
    # Number of refinements done by refine_question in the current turn.
    rephrase_count: int
    # The latest user message; supplied as input on every invocation.
    question: HumanMessage


# Structured-output schema used by question_classifier: the model fills
# `score` with 'Yes' or 'No'.
class GradeQuestion(BaseModel):
    score: str = Field(
        description="Question is about the specified topics? If yes -> 'Yes' if not -> 'No'"
    )


""" 
Purpose: If the user's question depends on past messages, rewrite it to be standalone. 
Example:
Before: "Also on Sunday?"
After: "Is Bella Vista restaurant open on Sundays?"
"""
def question_rewriter(state: AgentState):
    """Turn the latest user question into a standalone retrieval query.

    Resets the per-turn fields of ``state`` (``documents``, ``on_topic``,
    ``rephrased_question``, ``proceed_to_generate``, ``rephrase_count``),
    records the incoming question in ``state["messages"]`` and writes the
    query to use for retrieval to ``state["rephrased_question"]``.

    When earlier messages exist, the chat model is asked to rephrase the
    question using that history. On the first turn the question text is
    used unchanged and no model call is made.

    Args:
        state: Graph state. Must contain ``"question"``; ``"messages"`` may
            be missing or ``None`` on the first turn.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print(f"Entering question_rewriter with following state: {state}")

    # Reset state variables except for 'question' and 'messages'
    state["documents"] = []
    state["on_topic"] = ""
    state["rephrased_question"] = ""
    state["proceed_to_generate"] = False
    state["rephrase_count"] = 0

    if state.get("messages") is None:
        state["messages"] = []

    question = state["question"]
    history = state["messages"]

    # Compare against the last message only. A membership test over the whole
    # history would skip recording a question that repeats an earlier one, and
    # the slice below would then drop the latest AI answer instead of the
    # current question.
    if not history or history[-1] != question:
        history.append(question)

    if len(history) > 1:
        # Earlier turns exist: give the model the conversation (everything
        # before the current question) followed by the question itself.
        chat_messages = [
            SystemMessage(
                content="You are a helpful assistant that rephrases the user's question to be a standalone question optimized for retrieval."
            ),
            *history[:-1],
            HumanMessage(content=question.content),
        ]
        llm = ChatOpenAI(model="gpt-4o-mini")
        # Pass the message list itself. Formatting it into one string first
        # would send a single human message and lose the system/human/AI roles.
        response = llm.invoke(chat_messages)

        better_question = response.content.strip()
        print(f"question_rewriter: Rephrased question: {better_question}")
        state["rephrased_question"] = better_question
    else:
        # First turn: the question is already standalone.
        state["rephrased_question"] = question.content
    return state


""" 
Design note on this node:
It decides whether the question is on topic or off topic.
Could we simply skip this step? The graph would then always fetch documents from the vector DB and grade them; if none are relevant, it would refine the question and try again,
at most 3 times. If nothing relevant is found after that, the question was probably off topic.
I believe that is a better workflow than hardcoding the allowed topics and treating every other question as off topic.

Your suggestion has merit - instead of explicitly classifying questions as on-topic or off-topic upfront, we could:

Always attempt to retrieve documents
Grade the relevance of retrieved documents
If no relevant documents are found after multiple attempts at question refinement, conclude the question was likely off-topic


drawbacks:

Efficiency: You'd always make at least one retrieval call, which costs time and potentially money if using paid vector storage.
Precision: The explicit classifier could be more precise about certain topics that are clearly out of scope.
User Experience: A user asking about completely unrelated topics might get a response like "I couldn't find information about that" rather than a more direct "I can only answer questions about restaurant hours, prices, and ownership."
Resource Usage: More database queries and potentially more LLM calls if you're constantly trying to refine irrelevant questions.

Hybrid approach:

def lightweight_classifier(state: AgentState):
    \"""A very simple classifier that only filters out obviously unrelated topics\"""
    question = state["rephrased_question"].lower()
    
    # List of obviously off-topic keywords
    off_topic_indicators = [
        "stock market", "weather", "sports scores", "politics", 
        "movie times", "crypto", "bitcoin"
    ]
    
    # Only reject if obviously off-topic
    for topic in off_topic_indicators:
        if topic in question:
            state["off_topic"] = True
            return state
    
    # Otherwise, proceed to retrieval
    state["off_topic"] = False
    return state

def router_after_classification(state: AgentState):
    if state["off_topic"]:
        return "off_topic_response"
    return "retrieve"

"""
def question_classifier(state: AgentState):
    """Ask the chat model whether the rephrased question is on topic.

    Sends a fixed classification instruction plus
    ``state["rephrased_question"]`` to the model, requests a
    ``GradeQuestion`` structured result, and stores its stripped ``score``
    in ``state["on_topic"]``.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print("Entering question_classifier")
    instructions = SystemMessage(
        content="""You are a classifier that determines whether a user's question is about one of the following topics:

    1. Information about the owner of Bella Vista, which is Antonio Rossi.
    2. Prices of dishes at Bella Vista (restaurant).
    3. Opening hours of Bella Vista (restaurant).

    If the question IS about any of these topics, respond with 'Yes'. Otherwise, respond with 'No'."""
    )
    user_turn = HumanMessage(
        content=f"User question: {state['rephrased_question']}"
    )

    # Both messages are already fully rendered, so the prompt has no template
    # variables and the chain is invoked with an empty input mapping.
    classification_prompt = ChatPromptTemplate.from_messages([instructions, user_turn])
    # The model is asked to answer in the shape of GradeQuestion.
    classifier = ChatOpenAI(model="gpt-4o-mini").with_structured_output(GradeQuestion)
    verdict = (classification_prompt | classifier).invoke({})

    # Store the verdict without surrounding whitespace.
    state["on_topic"] = verdict.score.strip()
    print(f"question_classifier: on_topic = {state['on_topic']}")
    return state


def on_topic_router(state: AgentState):
    """Choose the node that follows the classifier.

    Returns:
        ``"retrieve"`` when ``state["on_topic"]`` is "yes" (ignoring case and
        surrounding whitespace), otherwise ``"off_topic_response"``. A
        missing verdict counts as off topic.
    """
    print("Entering on_topic_router")
    verdict = state.get("on_topic", "").strip().lower()
    if verdict != "yes":
        print("Routing to off_topic_response")
        return "off_topic_response"
    print("Routing to retrieve")
    return "retrieve"




def retrieve(state: AgentState):
    """Fetch documents for the rephrased question and store them in the state.

    Returns:
        The same ``state`` mapping with ``state["documents"]`` replaced by
        the retriever's results.
    """
    print("Entering retrieve")

    # There is no explicit embedding step here: the query is passed as plain
    # text and the retriever is expected to embed it itself, using the
    # embedding function the vector store was created with, before running
    # the similarity search.
    hits = retriever.invoke(state["rephrased_question"])
    print(f"retrieve: Retrieved {len(hits)} documents")
    state["documents"] = hits
    return state


# Structured-output schema used by retrieval_grader: the model fills `score`
# with 'Yes' or 'No' for each retrieved document.
class GradeDocument(BaseModel):
    score: str = Field(
        description="Document is relevant to the question? If yes -> 'Yes' if not -> 'No'"
    )


def retrieval_grader(state: AgentState):
    """Keep only the retrieved documents the model grades as relevant.

    Each document in ``state["documents"]`` is graded separately against
    ``state["rephrased_question"]``. Documents graded "yes" (case-insensitive)
    replace the list in the state, and ``state["proceed_to_generate"]``
    records whether any document survived.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print("Entering retrieval_grader")
    grading_instructions = SystemMessage(
        content="""You are a grader assessing the relevance of a retrieved document to a user question.
                    Only answer with 'Yes' or 'No'.

                    If the document contains information relevant to the user's question, respond with 'Yes'.
                    Otherwise, respond with 'No'."""
    )

    # The model is asked to answer in the shape of GradeDocument.
    grader = ChatOpenAI(model="gpt-4o-mini").with_structured_output(GradeDocument)

    kept = []
    for candidate in state["documents"]:
        request = HumanMessage(
            content=f"User question: {state['rephrased_question']}\n\nRetrieved document:\n{candidate.page_content}"
        )
        # One model call per document; the prompt is already fully rendered.
        chain = ChatPromptTemplate.from_messages([grading_instructions, request]) | grader
        verdict = chain.invoke({}).score.strip()
        print(
            f"Grading document: {candidate.page_content[:30]}... Result: {verdict}"
        )
        if verdict.lower() == "yes":
            kept.append(candidate)

    state["documents"] = kept
    state["proceed_to_generate"] = bool(kept)
    print(f"retrieval_grader: proceed_to_generate = {state['proceed_to_generate']}")
    return state


def proceed_router(state: AgentState):
    """Choose the node that follows document grading.

    Returns:
        ``"generate_answer"`` when relevant documents were found,
        ``"cannot_answer"`` when none were found and the question has
        already been refined twice, otherwise ``"refine_question"``.
    """
    print("Entering proceed_router")
    attempts = state.get("rephrase_count", 0)
    if state.get("proceed_to_generate", False):
        print("Routing to generate_answer")
        return "generate_answer"
    if attempts < 2:
        print("Routing to refine_question")
        return "refine_question"
    print("Maximum rephrase attempts reached. Cannot find relevant documents.")
    return "cannot_answer"


def refine_question(state: AgentState):
    """Ask the chat model for a slightly adjusted question to retry retrieval.

    Does nothing when two refinements have already been made. Otherwise the
    model's reply replaces ``state["rephrased_question"]`` and
    ``state["rephrase_count"]`` is incremented.

    Args:
        state: Graph state; must contain ``"rephrased_question"``.
            ``"rephrase_count"`` defaults to 0 when absent.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print("Entering refine_question")
    rephrase_count = state.get("rephrase_count", 0)
    if rephrase_count >= 2:
        # Safety net mirroring the limit enforced by proceed_router.
        print("Maximum rephrase attempts reached")
        return state

    question_to_refine = state["rephrased_question"]
    chat_messages = [
        SystemMessage(
            content="""You are a helpful assistant that slightly refines the user's question to improve retrieval results.
Provide a slightly adjusted version of the question."""
        ),
        HumanMessage(
            content=f"Original question: {question_to_refine}\n\nProvide a slightly refined question."
        ),
    ]
    llm = ChatOpenAI(model="gpt-4o-mini")
    # Pass the message list itself. Formatting it into one string first would
    # send a single human message and drop the system role.
    response = llm.invoke(chat_messages)

    refined_question = response.content.strip()
    print(f"refine_question: Refined question: {refined_question}")

    state["rephrased_question"] = refined_question
    state["rephrase_count"] = rephrase_count + 1
    return state


def generate_answer(state: AgentState):
    """Generate the final answer from the history and the relevant documents.

    Invokes ``rag_chain`` with the conversation history, the text of the
    documents in ``state["documents"]`` and ``state["rephrased_question"]``,
    then appends the stripped reply to ``state["messages"]`` as an
    ``AIMessage``.

    Args:
        state: Graph state containing ``"messages"``, ``"documents"`` and
            ``"rephrased_question"``.

    Returns:
        The same ``state`` mapping, updated in place.

    Raises:
        ValueError: If ``"messages"`` is missing from the state or is ``None``.
    """
    print("Entering generate_answer")
    if state.get("messages") is None:
        raise ValueError("State must include 'messages' before generating an answer.")

    # Render both inputs as plain text. Interpolating the raw lists into the
    # string template would put Python reprs (class names, metadata, kwargs)
    # into the prompt instead of the conversation and passage text.
    history = "\n".join(
        f"{message.type}: {message.content}" for message in state["messages"]
    )
    context = "\n\n".join(
        document.page_content for document in state["documents"]
    )
    rephrased_question = state["rephrased_question"]

    response = rag_chain.invoke(
        {"history": history, "context": context, "question": rephrased_question}
    )

    generation = response.content.strip()

    state["messages"].append(AIMessage(content=generation))
    print(f"generate_answer: Generated response: {generation}")
    return state


def cannot_answer(state: AgentState):
    """Append a fixed apology to the conversation.

    Creates ``state["messages"]`` first when it is missing or ``None``.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print("Entering cannot_answer")
    if state.get("messages") is None:
        state["messages"] = []
    apology = AIMessage(
        content="I'm sorry, but I cannot find the information you're looking for."
    )
    state["messages"].append(apology)
    return state


def off_topic_response(state: AgentState):
    """Append a fixed refusal to the conversation.

    Creates ``state["messages"]`` first when it is missing or ``None``.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    print("Entering off_topic_response")
    if state.get("messages") is None:
        state["messages"] = []
    refusal = AIMessage(content="I can't respond to that!")
    state["messages"].append(refusal)
    return state


'''
Instead, you could have the LLM generate contextually appropriate responses by:

For off-topic questions: Have the LLM explain what topics it can assist with while politely declining the current question
For no-relevant-documents scenarios: Have the LLM generate a response that acknowledges it couldn't find specific information

Here's how you could refactor these nodes:

def cannot_answer(state: AgentState):
    print("Entering cannot_answer")
    if "messages" not in state or state["messages"] is None:
        state["messages"] = []
    
    system_message = SystemMessage(
        content="""You are a helpful assistant that specializes in answering questions about 
        restaurant Bella Vista. A user has asked a question, but after several attempts, 
        no relevant information could be found. Politely explain that you don't have the 
        specific information they're looking for, but mention what topics you can help with 
        (restaurant hours, menu prices, and owner information)."""
    )
    
    human_message = HumanMessage(
        content=f"Question: {state['rephrased_question']}"
    )
    
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = llm.invoke([system_message, human_message])
    
    state["messages"].append(response)
    return state


def off_topic_response(state: AgentState):
    print("Entering off_topic_response")
    if "messages" not in state or state["messages"] is None:
        state["messages"] = []
    
    system_message = SystemMessage(
        content="""You are a helpful assistant that specializes in answering questions about 
        restaurant Bella Vista. The user has asked a question that is outside your area of 
        expertise. Politely explain that you can only answer questions about Bella Vista 
        restaurant, specifically about its opening hours, menu prices, and owner information."""
    )
    
    human_message = HumanMessage(
        content=f"Question: {state['rephrased_question']}"
    )
    
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = llm.invoke([system_message, human_message])
    
    state["messages"].append(response)
    return state

'''

In [None]:
from langgraph.checkpoint.memory import MemorySaver

# Checkpointer passed to workflow.compile(): it keeps the graph state between
# invoke() calls that use the same thread_id, which is how the conversation
# history reaches later turns.
checkpointer = MemorySaver()

In [None]:
# Workflow
# Build the graph over the shared AgentState.
workflow = StateGraph(AgentState)
# Register each node under the name the edges below refer to.
workflow.add_node("question_rewriter", question_rewriter)
workflow.add_node("question_classifier", question_classifier)
workflow.add_node("off_topic_response", off_topic_response)
workflow.add_node("retrieve", retrieve)
workflow.add_node("retrieval_grader", retrieval_grader)
workflow.add_node("generate_answer", generate_answer)
workflow.add_node("refine_question", refine_question)
workflow.add_node("cannot_answer", cannot_answer)

# Every turn is rewritten first, then classified.
workflow.add_edge("question_rewriter", "question_classifier")
# on_topic_router returns one of the keys below; the value is the next node.
workflow.add_conditional_edges(
    "question_classifier",
    on_topic_router,
    {
        "retrieve": "retrieve",
        "off_topic_response": "off_topic_response",
    },
)
workflow.add_edge("retrieve", "retrieval_grader")
# proceed_router answers, retries with a refined question, or gives up.
workflow.add_conditional_edges(
    "retrieval_grader",
    proceed_router,
    {
        "generate_answer": "generate_answer",
        "refine_question": "refine_question",
        "cannot_answer": "cannot_answer",
    },
)
# Retry loop: a refined question goes back through retrieval and grading.
workflow.add_edge("refine_question", "retrieve")
# The three reply nodes end the run.
workflow.add_edge("generate_answer", END)
workflow.add_edge("cannot_answer", END)
workflow.add_edge("off_topic_response", END)
workflow.set_entry_point("question_rewriter")
# Compile with the checkpointer so state is kept per thread_id across calls.
graph = workflow.compile(checkpointer=checkpointer)

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Show the compiled graph as a PNG rendered from its Mermaid diagram.
# NOTE(review): MermaidDrawMethod.API presumably renders through a remote
# service and so needs network access — confirm.
display(
    Image(
        graph.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

### Off topic

In [None]:
# A question outside the classifier's three topics, run on its own thread.
input_data = {"question": HumanMessage(content="How is the weather?")}
graph.invoke(input=input_data, config={"configurable": {"thread_id": 1}})

### No docs found

In [None]:
# On-topic question whose answer (the owner's age) is not in any of the
# indexed passages; run on its own thread.
input_data = {
    "question": HumanMessage(
        content="How old is the owner of the restaurant Bella Vista?"
    )
}
graph.invoke(input=input_data, config={"configurable": {"thread_id": 2}})

### RAG with history

In [None]:
# AgentState declares a "question" key, so this message is stored in the
# graph state under state["question"], where question_rewriter reads it.
# Keep this note as a comment: a string literal placed directly before the
# key inside the dict would be concatenated with "question" and silently
# change the dict key.
input_data = {
    "question": HumanMessage(content="When does the Bella Vista restaurant open?")
}


graph.invoke(input=input_data, config={"configurable": {"thread_id": 3}})

In [None]:
# Follow-up that only makes sense with the previous turn; it reuses
# thread_id 3 so the checkpointer supplies the earlier messages.
input_data = {"question": HumanMessage(content="Also on sunday?")}

graph.invoke(input=input_data, config={"configurable": {"thread_id": 3}})

""" 
The critical part is this:

When the question rewriter runs, it expects state["messages"] to contain the conversation history.
But for the second invocation, your input only contains the new question: {"question": HumanMessage(content="Also on sunday?")}
Without MemorySaver, the state would be initialized fresh on every invocation from just your input, so state["messages"] would be empty or contain only the current question. There would then be no chat history to give the LLM in the two nodes that need it (question_rewriter and generate_answer).

This is why you need MemorySaver: it persists the messages list between invocations that share a thread_id, so the second invocation can see the conversation history from the first. In other words, state['messages'] is not reset when the second invocation starts.
"""