In [1]:
from typing import Annotated, List, Optional, Sequence, TypedDict

from dotenv import load_dotenv
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.errors import GraphRecursionError
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

In [2]:
load_dotenv()  # load variables from a local .env into the environment (presumably the Google API key)
DEBUG: bool = False  # True -> node-by-node tracing in the run cell and graph.png export

In [3]:
# ---- State Definition ----
class AgentState(TypedDict):
    """Shared state passed between the is_clear, clarify and answer nodes."""

    # Question being worked on; clarify() appends the user's clarification text to it.
    question: str
    # "clear" / "unclear" as set by is_clear(); the graph routes on it.
    clarity_status: str
    # AI-given reasons the question was unclear; updates merged by add_messages.
    clarification_reason: Annotated[Sequence[BaseMessage], add_messages]
    # User clarification replies from clarify(); updates merged by add_messages.
    clarification: Annotated[Sequence[BaseMessage], add_messages]
    # Final answer text produced by answer(); None until then.
    answer: Optional[str]

In [4]:
tools: list = []  # no tools yet; passed to llm.bind_tools() below

In [5]:
# ---- Model Binding ----
# ChatGoogleGenerativeAI is imported in the first cell; no need to re-import it here.
MODEL_NAME = "gemini-1.5-flash-latest"
TEMPERATURE = 0.3

# Shared chat model used by is_clear() and answer(). `tools` is currently empty,
# so no tools are actually bound.
llm = ChatGoogleGenerativeAI(model=MODEL_NAME, temperature=TEMPERATURE).bind_tools(tools)

In [6]:
def is_clear(state: AgentState) -> AgentState:
    """Ask the LLM whether ``state["question"]`` is clear enough to answer.

    Returns a partial state update:
      * ``{"clarity_status": "unclear", "clarification_reason": [AIMessage(reason)]}``
        when the model's reply starts with "unclear"; ``reason`` is the text
        after the first ":" or a fallback message if the model gave none.
      * ``{"clarity_status": "clear", "clarification_reason": []}`` otherwise.

    The graph routes on ``clarity_status``: "clear" -> answer, "unclear" -> clarify.
    """
    question_to_check = state["question"]
    if DEBUG:
        print(f"DEBUG is_clear: Checking question: '{question_to_check}'")

    sys_prompt = SystemMessage(content="You are a helpful assistant that determines if a question is clear. Your primary goal is to understand if you have enough information to provide a comprehensive answer directly. If the question is ambiguous, too broad, or seems to be missing crucial context that a user would typically provide, mark it as unclear. Do not be overly critical for simple questions.")
    messages = [
        sys_prompt,
        HumanMessage(content=f"Is this question clear enough to answer directly? Respond only with 'clear' or 'unclear: <reason why it's unclear and what clarification is needed>'.\n\nQuestion: {question_to_check}")
    ]
    llm_response = llm.invoke(messages)
    result = llm_response.content.strip()
    if DEBUG:
        print(f"DEBUG is_clear: LLM raw response for clarity: '{result}'")

    # The prompt shows the expected replies in quotes, and models sometimes echo
    # them with quotes or markdown emphasis ("'unclear: ...'", "**Unclear**: ...").
    # Strip that decoration before matching so such replies aren't misread as "clear".
    decoration = "'\"*` "
    verdict = result.strip(decoration)
    if verdict.lower().startswith("unclear"):
        _, sep, tail = verdict.partition(":")
        reason = tail.strip(decoration) if sep else ""
        if not reason:
            # Covers both "unclear" with no colon and "unclear:" with nothing after it.
            reason = "The LLM indicated the question was unclear but did not provide a specific reason."
        return {"clarity_status": "unclear", "clarification_reason": [AIMessage(content=reason)]}

    # An empty list is a no-op under the add_messages reducer: earlier reasons are
    # kept (answer() uses them as context). Only the status changes here.
    return {"clarity_status": "clear", "clarification_reason": []}
    
    

In [7]:
def answer(state: AgentState) -> AgentState:
    """Answer ``state["question"]`` using any clarification history as context.

    Builds one prompt from the question, the accumulated AI unclarity reasons
    (``clarification_reason``) and the user's replies (``clarification``), sends
    it to ``llm`` and returns ``{"answer": <stripped reply text>}``.

    NOTE(review): clarify() already appends each clarification to the question
    text, so user clarifications appear twice in the prompt.
    """
    ai_reasons = state.get("clarification_reason", [])
    user_notes = state.get("clarification", [])

    lines = ["Please answer the following question:", state["question"]]
    if ai_reasons:
        lines += ["\nContext - Initial Unclarity & AI Reasoning:"]
        lines += [f"- AI thought: {note.content}" for note in ai_reasons]
    if user_notes:
        lines += ["\nContext - User Clarifications:"]
        lines += [f"- User clarified: {note.content}" for note in user_notes]
    prompt_content = "\n".join(lines)

    if DEBUG:
        print(f"DEBUG answer: Prompt content for answer: '{prompt_content}'")

    system = SystemMessage(content="You are a helpful assistant that answers questions clearly. Use any provided context about previous unclarity and user clarifications to inform your answer to the current question.")
    reply = llm.invoke([system, HumanMessage(content=prompt_content)])
    return {"answer": reply.content.strip()}

In [8]:
def clarify(state: AgentState) -> AgentState:
    """Ask the user to clarify a question that is_clear() judged unclear.

    Prints the current question and the most recent AI reason, then reads a
    clarification from stdin, re-prompting while the reply is blank.

    Returns:
        A partial update with ``question`` rewritten as
        ``"<question> (User Clarification: <reply>)"`` and the reply added as a
        HumanMessage to ``clarification``.
    """
    reason_messages = state.get("clarification_reason", [])
    reason_for_clarification = "The question was deemed unclear."  # Default
    if reason_messages and hasattr(reason_messages[-1], "content"):
        # Newest reason from is_clear() is the last message in the list.
        reason_for_clarification = reason_messages[-1].content

    original_question = state["question"]

    print(f"🤖 AI: I'm a bit unsure about your question: \"{original_question}\"")
    print(f"🤖 AI Reason: {reason_for_clarification}")
    print("🤖 AI: Could you rephrase or provide more details?")
    # A blank reply would add an empty "(User Clarification: )" and spend another
    # is_clear -> clarify round trip against the recursion limit, so ask again.
    clarified_input = input("🧑 You: ").strip()
    while not clarified_input:
        clarified_input = input("🧑 You: ").strip()

    revised_question = f"{original_question} (User Clarification: {clarified_input})"
    if DEBUG:
        print(f"DEBUG clarify: Revised question: '{revised_question}'")
    return {
        "question": revised_question,
        "clarification": [HumanMessage(content=clarified_input)],
    }

In [9]:
# ---- Graph Definition ----
# Flow: is_clear_node --clear-->   answer_node --> END
#       is_clear_node --unclear--> clarify_node --> is_clear_node (loops until clear)
builder = StateGraph(AgentState)

# Nodes
builder.add_node("is_clear_node", is_clear)
builder.add_node("clarify_node", clarify)
builder.add_node("answer_node", answer)

# Edges
builder.set_entry_point("is_clear_node")
builder.add_conditional_edges(
    "is_clear_node",
    lambda state: state["clarity_status"],  # "clear" / "unclear", set by is_clear
    {"clear": "answer_node", "unclear": "clarify_node"},
)
builder.add_edge("answer_node", END)
builder.add_edge("clarify_node", "is_clear_node")

graph = builder.compile()

In [10]:
# ---- Save Graph Image (Only if DEBUG is True) ----
if DEBUG:
    try:
        # Render first: open(..., "wb") truncates immediately, so opening before
        # rendering would leave an empty graph.png behind if rendering fails.
        png_bytes = graph.get_graph().draw_mermaid_png()
        with open("graph.png", "wb") as f:
            f.write(png_bytes)
        print("✅ Graph image saved as 'graph.png'")
    except Exception as e:  # best-effort: a missing diagram must not stop the notebook
        print(f"⚠️ Could not save graph image: {e}. draw_mermaid_png() renders via the mermaid.ink web API by default, so network access is required.")

In [11]:
# ---- Running the graph with DETAILED FLOW ----
initial_question = input("🧑 Ask a question: ")

# Every key is populated up front: the DEBUG trace in the next cell reads each
# one directly, and the message channels start as empty lists.
initial_state = AgentState(
    question=initial_question,
    clarity_status="",
    clarification_reason=[],
    clarification=[],
    answer=None,
)

In [None]:

# ---- Run the graph ----
# DEBUG: stream node-by-node updates and print each alongside a locally merged
# view of the state. Otherwise: one invoke() that returns the final state.
final_state_data = None
RECURSION_LIMIT = 10  # max graph steps; each clarify -> is_clear round trip uses two

# State keys whose annotation carries the add_messages reducer: their updates
# are appended to, not substituted for, the previous value.
MESSAGE_KEYS = {
    key
    for key, hint in AgentState.__annotations__.items()
    if add_messages in getattr(hint, "__metadata__", ())
}


def _fmt_message(msg: BaseMessage) -> str:
    """Render a message as TypeName(content="...") for debug printing."""
    return f"{type(msg).__name__}(content=\"{msg.content}\")"


def _print_node_update(step: int, node_name: str, update: Optional[dict]) -> None:
    """Print the raw partial-state update one node returned."""
    print(f"--- Step {step}: Node Executed: '{node_name}' ---")
    print("  Node Output:")
    if not update:
        print("    (No direct output values returned or node is END)")
        return
    for key, value in update.items():
        if isinstance(value, list) and all(isinstance(item, BaseMessage) for item in value):
            print(f"    {key}:")
            for msg_idx, msg in enumerate(value):
                print(f"      [{msg_idx}] {_fmt_message(msg)}")
        else:
            print(f"    {key}: {value}")


def _merge_update(display_state: dict, update: Optional[dict]) -> None:
    """Fold a node update into the display copy (LangGraph keeps the real state)."""
    for key, value in (update or {}).items():
        if key not in display_state:
            continue
        if key in MESSAGE_KEYS and isinstance(value, list):
            # Build a new list so lists shared with initial_state are never mutated.
            display_state[key] = list(display_state[key] or []) + list(value)
        else:
            display_state[key] = value


def _print_display_state(display_state: dict) -> None:
    """Print the merged state after a node's contribution."""
    print("  Current Relevant State (after this node's contribution):")
    print(f"    Question: \"{display_state['question']}\"")
    print(f"    Clarity Status: \"{display_state['clarity_status']}\"")
    for key, label in (
        ("clarification_reason", "Accumulated Clarification Reasons (AI)"),
        ("clarification", "Accumulated Clarifications (User)"),
    ):
        if display_state[key]:
            print(f"    {label}:")
            for msg_idx, msg in enumerate(display_state[key]):
                print(f"      [{msg_idx}] {_fmt_message(msg)}")
        else:
            print(f"    {label}: []")
    if display_state["answer"]:
        print(f"    Answer: \"{display_state['answer']}\"")


display_state = None
try:
    if DEBUG:
        print("\n🚀 Starting Graph Execution Flow (DEBUG MODE)...\n")
        # Copy list values so merging never aliases initial_state's lists.
        display_state = {
            key: list(value) if isinstance(value, list) else value
            for key, value in initial_state.items()
        }
        stream = graph.stream(initial_state, {"recursion_limit": RECURSION_LIMIT})
        for step, event_chunk in enumerate(stream, start=1):
            for node_name, node_update in event_chunk.items():
                _print_node_update(step, node_name, node_update)
                _merge_update(display_state, node_update)
                _print_display_state(display_state)
                print("-" * 50)
        # Stream chunks are {node_name: update}; END never appears as a node, so
        # the loop simply ends when the graph finishes. The merged view is the
        # final state (previously this was never assigned, so a successful DEBUG
        # run always fell through to the "unexpected error" message).
        print("\n🏁 Graph has reached END.")
        final_state_data = display_state
    else:  # Normal operation (DEBUG is False)
        print("Processing your question...")  # Minimal feedback
        final_state_data = graph.invoke(initial_state, {"recursion_limit": RECURSION_LIMIT})
except GraphRecursionError:
    print(f"\n⚠️ Stopped after reaching the recursion limit ({RECURSION_LIMIT} steps) while still clarifying the question.")
    # Keep what is known so the output cell reports "no final answer" instead of an error.
    final_state_data = display_state if display_state is not None else dict(initial_state)

# ---- Final Output ----
if final_state_data and final_state_data.get("answer"):
    print("\n💡 AI Answer:", final_state_data["answer"])
elif final_state_data:
    print("\n🤔 AI: I couldn't arrive at a final answer for your question after the process.")
    reasons = final_state_data.get("clarification_reason") or []
    if DEBUG and reasons:
        # The label promises the *last* reason, so print only that one
        # (previously every accumulated reason was printed under it).
        last_reason = reasons[-1]
        print("    Last reason for unclarity noted by AI:")
        print(f"      - {type(last_reason).__name__}(content=\"{last_reason.content}\")")
else:
    print("\n⚠️ An unexpected error occurred, and no final state was determined.")

Processing your question...

💡 AI Answer: 1 + 2 = 3
