# In [None]:
import os
from typing import TypedDict, Annotated, Literal
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.tools import tool
from langchain_core.runnables import RunnableLambda
import operator

# Set your OpenAI API key
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
# NOTE: OpenAIEmbeddings and ChatOpenAI below both rely on this key, and the
# FAISS index is built as soon as this cell runs, so set the key first.

# Sample USA economy data (as of July 2025)
# Each string is one retrievable document; the trailing "Source: ..." text is
# part of the document content itself, not separate metadata.
usa_economy_docs = [
    "The US economy is projected to expand at a pace of 1.6% year-over-year in 2025, down from 2.8% in 2024. Source: The Conference Board",
    "US current-account deficit widened to $450.2 billion in Q1 2025. Source: BEA",
    "Despite much lower tariffs, the US economy is still expected to grow at a slower rate in 2025 compared with the previous two years. Source: Deloitte",
    "Several key economic predictions for 2025 include weaker US economic growth. Source: S&P Global",
    "This year began with high expectations but following strong 2.8% year-over-year growth in GDP. Source: NRF"
]

# Set up in-memory vectorstore for RAG
# FAISS.from_texts embeds every document above once, up front, using the
# OpenAI embedding model; the index lives only in this process's memory.
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_texts(usa_economy_docs, embeddings)
# Each retrieval returns the 3 most similar documents (out of the 5 above).
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# LLM instance
# Shared by every node (router, answer nodes, validator). temperature=0 asks
# for the least random completions the API offers.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Define state
class AgentState(TypedDict):
    """Shared LangGraph state passed between the nodes of the workflow."""

    # The user's question, as supplied in the initial state.
    query: str
    # Routing label written by router_node. TypedDict annotations are not
    # enforced at runtime, so this Literal documents the expected values only.
    category: Literal["general", "usa_economy", "live_news"]
    # Latest answer produced by one of the answer nodes.
    output: str
    # True when validation_node accepted `output`.
    validation_result: bool
    # Validator critique fed into the next answer attempt; the answer nodes
    # reset it to "" after using it.
    feedback: str
    # Failed-validation counter. The `operator.add` annotation is a LangGraph
    # reducer: a value returned by a node is *added* to the current count
    # instead of replacing it, so nodes should return the increment (e.g. 1),
    # not the new total.
    attempts: Annotated[int, operator.add]

# Router node: Classify the query
def _normalize_category(raw: str) -> str:
    """Map a free-form classifier reply onto one of the known category names.

    The model is told to output only the category name, but replies such as
    "`usa_economy`", "usa_economy." , "Category: live_news" or "usa economy"
    still happen. The earliest known category name found in the normalised
    reply wins; anything unrecognised falls back to "general".
    """
    # Lower-case and turn spaces/hyphens into underscores so that
    # "USA Economy" or "live-news" still match the canonical labels.
    text = raw.strip().lower().replace("-", "_").replace(" ", "_")
    found = [
        (text.find(name), name)
        for name in ("usa_economy", "live_news", "general")
        if name in text
    ]
    return min(found)[1] if found else "general"


def router_node(state: AgentState) -> dict:
    """Classify ``state["query"]`` into a routing category with the LLM.

    Returns a partial state update ``{"category": ...}`` whose value is always
    one of "general", "usa_economy" or "live_news", so the conditional edge
    after this node never receives an unexpected label.
    """
    classify_prompt = ChatPromptTemplate.from_template(
        """Classify the following query into one of these categories:
        - general: for general queries not related to USA economy or live news
        - usa_economy: if the query is about USA economy or anything related to USA
        - live_news: if the query asks for live or current information, news, or real-time data
        
        Query: {query}
        
        Output only the category name."""
    )
    chain = classify_prompt | llm
    response = chain.invoke({"query": state["query"]})
    return {"category": _normalize_category(response.content)}

# General LLM node
def general_node(state: AgentState) -> dict:
    """Answer the query directly with the LLM (no retrieval, no tools).

    When a previous attempt was rejected, the validator's feedback is appended
    to the prompt. Returns the new ``output`` and resets ``feedback`` to "".
    """
    feedback = state.get("feedback")
    template = "Answer the query: {query}"
    variables = {"query": state["query"]}
    if feedback:
        template = template + "\nPrevious attempt failed with feedback: {feedback}. Correct your answer accordingly."
        variables["feedback"] = feedback
    answer = (ChatPromptTemplate.from_template(template) | llm).invoke(variables)
    # Empty feedback so a later retry does not reuse stale critique.
    return {"output": answer.content, "feedback": ""}

# USA Economy RAG node
def usa_rag_node(state: AgentState) -> dict:
    """Answer a USA-economy query grounded in the retrieved documents.

    The documents returned by ``retriever`` are newline-joined into a context
    block for the prompt; validator feedback, if present, is appended.
    Returns the new ``output`` and resets ``feedback`` to "".
    """
    question = state["query"]
    retrieved = retriever.invoke(question)
    context = "\n".join(doc.page_content for doc in retrieved)

    template = "Answer the query using the following context:\n{context}\n\nQuery: {query}"
    variables = {"query": question, "context": context}
    feedback = state.get("feedback")
    if feedback:
        template += "\nPrevious attempt failed with feedback: {feedback}. Correct your answer accordingly."
        variables["feedback"] = feedback

    reply = (ChatPromptTemplate.from_template(template) | llm).invoke(variables)
    # Empty feedback so a later retry does not reuse stale critique.
    return {"output": reply.content, "feedback": ""}

# Live News tool (mocked for demo; in real use, integrate a news API)
@tool
def get_live_news(query: str) -> str:
    """Fetch live news related to the query."""
    # Mock implementation; replace with real API call, e.g., using requests to news API
    return f"Live news as of July 2025 for '{query}': [Mock] US stock market is up 2% today. Economy shows signs of recovery."


# `@tool` already turns get_live_news into a Runnable tool object, so it can be
# invoked directly. Wrapping it in RunnableLambda made every call go through
# the tool's deprecated __call__ instead of .invoke().
live_news_tool = get_live_news

def live_news_node(state: AgentState) -> dict:
    """Answer a live/current-events query from the news tool's output.

    The news tool is called directly with the raw query and its text is given
    to the LLM as context. If a previous attempt was rejected, the validator's
    feedback is appended to the prompt. Returns the new ``output`` and resets
    ``feedback`` to "".
    """
    # For simplicity, call the tool directly; a full agent could let the LLM
    # decide when to call it.
    news = live_news_tool.invoke(state["query"])

    # The LLM is not bound to any tool, so hand it the fetched news as context
    # rather than asking it to "use the live news tool" (which invites
    # "I don't have access to live data" replies).
    prompt_template = "Answer the query using the following live news data:\n{news}\n\nQuery: {query}"
    inputs = {"query": state["query"], "news": news}
    if state.get("feedback"):
        prompt_template += "\nPrevious attempt failed with feedback: {feedback}. Correct your answer accordingly."
        inputs["feedback"] = state["feedback"]

    chain = ChatPromptTemplate.from_template(prompt_template) | llm
    output = chain.invoke(inputs).content
    return {"output": output, "feedback": ""}  # Clear feedback after use

# Validation node
def validation_node(state: AgentState) -> dict:
    """Ask the LLM whether ``state["output"]`` correctly answers ``state["query"]``.

    The verdict is the first standalone "yes"/"no" word in the reply
    (case-insensitive).

    Returns:
        ``{"validation_result": True}`` when the verdict is "yes". Otherwise
        ``validation_result=False``, the validator's explanation as
        ``feedback`` (the text after "No", the whole reply if no verdict word
        was found, or "Invalid output." if that is empty), and ``attempts=1``.

    ``attempts`` is declared with an ``operator.add`` reducer in AgentState,
    so LangGraph adds the returned value to the running count. Returning the
    increment (1) rather than ``state["attempts"] + 1`` keeps it an accurate
    count of failed validations (the old form grew 1, 3, 7, ...).
    """
    import re  # local import keeps the fix self-contained

    validate_prompt = ChatPromptTemplate.from_template(
        """Validate if the following output correctly and completely answers the query.
        Query: {query}
        Output: {output}
        
        Respond with 'Yes' if valid, 'No' if not. If No, provide feedback on why it's invalid and how to correct it."""
    )
    chain = validate_prompt | llm
    response = chain.invoke({"query": state["query"], "output": state["output"]}).content

    # Whole-word matching avoids false hits on "Yesterday", "Not", "Now", ...
    # which the previous substring checks accepted.
    verdict = re.search(r"\b(yes|no)\b", response, re.IGNORECASE)
    if verdict and verdict.group(1).lower() == "yes":
        return {"validation_result": True}

    explanation = response[verdict.end():] if verdict else response
    # Drop separators that typically follow the verdict ("No, ...", "No: ...").
    feedback = explanation.lstrip(" \t\r\n*_`'\".,:;-").strip() or "Invalid output."
    return {"validation_result": False, "feedback": feedback, "attempts": 1}

# Conditional routing after router
def route_after_router(state: AgentState) -> str:
    """Return the answer node for the category chosen by the router.

    "usa_economy" maps to the RAG node, "live_news" to the news node, and any
    other value falls through to the general node. The category is printed
    for visibility.
    """
    category = state["category"]
    print(f"Category: {category}")
    routes = (("usa_economy", "usa_rag"), ("live_news", "live_news"))
    return next((node for label, node in routes if category == label), "general")

# Conditional after validation
def decide_after_validation(state: AgentState) -> str:
    """Choose the next step once the validator has run.

    Returns "end" when the output was accepted or ``attempts`` exceeds 3;
    otherwise returns the answer node for ``state["category"]`` (the general
    node for unrecognised categories) so the query is answered again.
    """
    accepted = state["validation_result"]
    if accepted or state.get("attempts", 0) > 3:
        return "end"
    category = state["category"]
    retry_targets = (("usa_economy", "usa_rag"), ("live_news", "live_news"))
    return next((node for label, node in retry_targets if category == label), "general")

# Build the graph
def _build_workflow():
    """Assemble the uncompiled graph: router -> answer node -> validator loop."""
    graph = StateGraph(AgentState)
    # Answer nodes, keyed by the labels the routing functions return.
    answer_nodes = {
        "general": general_node,
        "usa_rag": usa_rag_node,
        "live_news": live_news_node,
    }

    graph.add_node("router", router_node)
    for node_name, node_fn in answer_nodes.items():
        graph.add_node(node_name, node_fn)
    graph.add_node("validate", validation_node)

    graph.set_entry_point("router")

    # Routing labels coincide with node names, so the path map is an identity.
    to_answer_node = {node_name: node_name for node_name in answer_nodes}
    graph.add_conditional_edges("router", route_after_router, to_answer_node)

    # Every answer is checked by the validator ...
    for node_name in answer_nodes:
        graph.add_edge(node_name, "validate")

    # ... which either retries the same kind of answer node or finishes.
    graph.add_conditional_edges(
        "validate",
        decide_after_validation,
        {**to_answer_node, "end": END},
    )
    return graph


workflow = _build_workflow()

# Compile the graph
app = workflow.compile()

# Example usage
if __name__ == "__main__":
    # Seed every field the nodes read before writing; "category" and "output"
    # are filled in by the graph itself.
    initial_state = dict(
        query="What is stock price of tesla",
        attempts=0,
        validation_result=False,
        feedback="",
    )
    result = app.invoke(initial_state)
    print("Final Output:", result["output"])

# Output from a sample run:
# Final Output: I don't have access to live data or the ability to check current stock prices. However, you can easily find the latest stock price of Tesla by checking financial news websites, stock market apps, or platforms like Google Finance or Yahoo Finance.
