In [None]:
# full_react_agent.py
import os
import json
import re
import ast
import sys
import traceback
from typing import Any, Optional, Dict, Callable

# Disable LangSmith logging (prevents unwanted 403 errors)
# os.environ["LANGCHAIN_TRACING_V2"] = "false"
# os.environ["LANGCHAIN_API_KEY"] = ""
# os.environ["LANGCHAIN_ENDPOINT"] = ""
# os.environ["LANGCHAIN_PROJECT"] = ""

from dotenv import load_dotenv
load_dotenv()

import requests

# Watsonx imports
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams

# LangGraph + messages
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage

# typing for AgentState
from typing import Annotated, List, TypedDict, Optional
import operator


# -----------------------------------------------------------
# Agent State
# -----------------------------------------------------------
class AgentState(TypedDict):
    """State shared between the LangGraph nodes of the ReAct loop."""

    # Conversation history. The operator.add reducer makes LangGraph append the
    # message lists returned by nodes rather than replacing the history.
    messages: Annotated[List[BaseMessage], operator.add]
    # Tool name selected by parse_and_decide; None routes the graph to END
    # (final answer given, or no parseable Action found).
    action_name: Optional[str]
    # Decoded "Action Input". Despite the hint this may hold any value that
    # json.loads / ast.literal_eval produced, or {"raw": text} when parsing failed.
    action_input: Optional[dict]
    # Unparsed text of the most recent LLM completion (written by react_llm_node).
    raw_output: Optional[str]


# -----------------------------------------------------------
# ReAct Prompt
# -----------------------------------------------------------
# Instructions for the ReAct loop. The template is rendered with str.format()
# (tool_name, tool_description), so literal braces in the JSON example are
# doubled. The "Tool:" prefix below must match how react_llm_node renders
# ToolMessages back into the conversation.
REACT_PROMPT = """
You are a specialized agent. Your goal is to answer the user's request.
You have access to the following tool: {tool_name} with the following description: {tool_description}

You must respond in one of two formats:

1. Final Answer:
Thought: I have enough information to answer the user.
Action: Final Answer
Action Input: The final answer goes here.

2. Tool Call:
Thought: I need to use the tool to find the answer.
Action: {tool_name}
Action Input: {{"query": "the search term goes here"}}

Rules:
- Reply with exactly one Thought / Action / Action Input block, then stop.
- For a tool call, the Action Input must be a single-line JSON object.
- Never write tool results yourself. They are added to the conversation on lines starting with "Tool:".

Begin.
"""


# -----------------------------------------------------------
# Watsonx LLM Client (Llama 3.3 70B instruct)
# -----------------------------------------------------------
def get_watsonx_llm(
    model_id: str = "meta-llama/llama-3-3-70b-instruct",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
) -> ModelInference:
    """Build a Watsonx ``ModelInference`` client from environment variables.

    Reads ``WATSONX_API_KEY``, ``WATSONX_URL`` and ``WATSONX_PROJECT_ID``.

    Args:
        model_id: Watsonx foundation-model identifier.
        max_new_tokens: Generation limit passed as ``MAX_NEW_TOKENS``.
        temperature: Sampling temperature passed as ``TEMPERATURE``.

    Returns:
        A ``ModelInference`` bound to a freshly created ``APIClient``.

    Raises:
        RuntimeError: if any of the three variables is unset or empty; the
            message lists exactly which ones are missing.
    """
    required = ("WATSONX_API_KEY", "WATSONX_URL", "WATSONX_PROJECT_ID")
    env = {name: os.getenv(name) for name in required}
    missing = [name for name, value in env.items() if not value]
    if missing:
        # Name the missing variables so misconfiguration is quick to diagnose.
        raise RuntimeError(
            "Missing Watsonx environment variables: " + ", ".join(missing)
        )

    creds = Credentials(api_key=env["WATSONX_API_KEY"], url=env["WATSONX_URL"])
    client = APIClient(credentials=creds)

    return ModelInference(
        model_id=model_id,
        api_client=client,
        project_id=env["WATSONX_PROJECT_ID"],
        params={
            GenParams.MAX_NEW_TOKENS: max_new_tokens,
            GenParams.TEMPERATURE: temperature,
        },
    )


# -----------------------------------------------------------
# REAL Tavily Search Tool
# -----------------------------------------------------------
def tavily_search(query: str, top_k: int = 5) -> dict:
    """Run a web search through the Tavily Search REST API.

    Args:
        query: Search string.
        top_k: Maximum number of results. Sent to Tavily as ``max_results``;
            Tavily has no ``top_k`` field, so sending that name silently
            ignored the caller's value.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: if ``TAVILY_API_KEY`` is not set.
        requests.HTTPError: on a non-2xx HTTP response (via ``raise_for_status``).
        Other ``requests`` exceptions on connection failure or the 25 s timeout.
    """
    api_key = os.getenv("TAVILY_API_KEY")
    # TAVILY_ENDPOINT stays an override; default to the public search endpoint.
    endpoint = os.getenv("TAVILY_ENDPOINT") or "https://api.tavily.com/search"

    if not api_key:
        raise RuntimeError("Missing Tavily environment variable: TAVILY_API_KEY")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    payload = {"query": query, "max_results": top_k}
    resp = requests.post(endpoint, headers=headers, json=payload, timeout=25)
    resp.raise_for_status()

    return resp.json()


# -----------------------------------------------------------
# Transcript Generator (Watsonx Llama)
# -----------------------------------------------------------
def generate_transcript(topic: str, llm: Optional[ModelInference] = None) -> str:
    """Write a spoken-style kids' YouTube transcript about *topic*.

    Args:
        topic: Subject of the video; interpolated verbatim into the prompt.
        llm: Model client to use. When omitted, one is built with the
            ``get_watsonx_llm()`` defaults.

    Returns:
        The generated text. A plain string reply is returned as-is; a dict
        reply carrying ``"results"`` yields the first result's
        ``"generated_text"`` (empty string if that key is absent); anything
        else is converted with ``str()``.
    """
    model = get_watsonx_llm() if llm is None else llm

    request_text = f"""
Generate a clear, spoken-style transcript for a YouTube trending video for kids about:
"{topic}"

Requirements:
- 400–700 words
- Conversational tone
- Clear structure: hook → points → ending
- No directions or metadata, only spoken words
"""

    reply = model.generate_text(request_text)

    if isinstance(reply, str):
        return reply
    if isinstance(reply, dict) and "results" in reply:
        first_result = reply["results"][0]
        return first_result.get("generated_text", "")
    return str(reply)


# -----------------------------------------------------------
# LLM Node (ReAct)
# -----------------------------------------------------------
def react_llm_node(
    state: AgentState,
    tool_name: str,
    tool_description: str,
    model_id: str = "meta-llama/llama-3-3-70b-instruct",
):
    """Run one ReAct reasoning step with the Watsonx model.

    The conversation so far is rendered as ``User:`` / ``Assistant:`` /
    ``Tool:`` lines (other message types are skipped) and appended to the
    formatted REACT_PROMPT. The prompt ends with an ``Assistant:`` cue so the
    completion model knows its turn starts there.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        tool_name: Tool name substituted into the prompt template.
        tool_description: Tool description substituted into the prompt template.
        model_id: Watsonx model identifier passed to ``get_watsonx_llm``.

    Returns:
        A partial state update: the completion wrapped in an ``AIMessage``
        (appended by the ``messages`` reducer) and the same text as ``raw_output``.
    """
    lines = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            lines.append(f"User: {msg.content}\n")
        elif isinstance(msg, AIMessage):
            lines.append(f"Assistant: {msg.content}\n")
        elif isinstance(msg, ToolMessage):
            lines.append(f"Tool: {msg.content}\n")
    # Join once instead of repeated string concatenation.
    conversation = "".join(lines)

    prompt = (
        REACT_PROMPT.format(tool_name=tool_name, tool_description=tool_description)
        + "\nConversation:\n"
        + conversation
        # Explicit turn cue: without it the completion may continue as the user.
        + "Assistant:"
    )

    llm = get_watsonx_llm(model_id=model_id)
    response = llm.generate_text(prompt)

    if isinstance(response, dict) and "results" in response:
        results = response["results"]
        # Guard against an empty results list instead of raising IndexError.
        raw_text = results[0].get("generated_text", "") if results else ""
    elif isinstance(response, str):
        raw_text = response
    else:
        raw_text = str(response)

    return {"messages": [AIMessage(content=raw_text)], "raw_output": raw_text}


# -----------------------------------------------------------
# Parse & Decide Router (ReAct)
# -----------------------------------------------------------
def _parse_action_input(raw: str) -> Any:
    """Best-effort decode of a ReAct ``Action Input`` payload.

    Tries, in order:
    1. strict JSON for the whole text;
    2. when the text starts with ``{``, the leading JSON object only. This
       tolerates trailing model output such as a hallucinated observation line;
    3. a Python literal via ``ast.literal_eval``, e.g. a single-quoted dict.

    Returns ``{"raw": raw}`` when none of these succeed.
    """
    try:
        return json.loads(raw)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        pass

    if raw.startswith("{"):
        try:
            value, _end = json.JSONDecoder().raw_decode(raw)
            return value
        except (ValueError, RecursionError):
            pass

    try:
        # literal_eval evaluates literals only, so it is safe on model output.
        return ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return {"raw": raw}


def parse_and_decide(state: AgentState):
    """Parse the latest LLM reply and decide whether a tool must run.

    Looks for the first ``Action: ...`` / ``Action Input: ...`` pair in the
    last message and decodes the input with ``_parse_action_input``.

    Returns:
        ``{}`` if the last message is not an ``AIMessage``. Otherwise a partial
        state update: ``action_name`` is the tool to call, or ``None`` for
        "Final Answer" (case-insensitive) or when no Action pair is found, and
        ``action_input`` is the decoded input (``None`` when nothing matched).
    """
    last = state["messages"][-1]
    if not isinstance(last, AIMessage):
        return {}

    text = last.content or ""

    print("\n--- Parsing ReAct Output ---")
    print(text[:600], "...\n")

    m = re.search(
        r"Action:\s*(.+?)\s*\nAction Input:\s*(.+)",
        text,
        re.DOTALL | re.MULTILINE
    )
    if not m:
        return {"action_name": None, "action_input": None}

    action_name = m.group(1).strip()
    action_input = _parse_action_input(m.group(2).strip())

    if action_name.lower() == "final answer":
        # A final answer ends the loop; route_fn treats None as "end".
        return {"action_name": None, "action_input": action_input}

    return {
        "action_name": action_name,
        "action_input": action_input,
    }


# -----------------------------------------------------------
# TOOL EXECUTOR with correct ToolMessage(tool_call_id=...)
# -----------------------------------------------------------
# Registry of callable tools, keyed by the name the LLM writes after "Action:".
TOOLS: Dict[str, Callable[..., Any]] = dict(tavily_search=tavily_search)

def tool_executor_node(state: AgentState):
    """Execute the tool chosen by the router and report the result.

    The decoded ``action_input`` is passed to the tool as keyword arguments.
    A bare string input (e.g. ``Action Input: "weather in Paris"``) is treated
    as ``{"query": <string>}``. Any other non-dict input is reported back as
    an error instead of being splatted with ``**``.

    Returns:
        ``{}`` when no tool was selected. Otherwise a ``ToolMessage`` holding
        the JSON-encoded result, or an error payload containing the message
        and traceback. ``action_name`` and ``action_input`` are reset to ``None``.
    """
    name = state.get("action_name")
    data = state.get("action_input") or {}

    if not name:
        return {}

    tool_fn = TOOLS.get(name)
    if not tool_fn:
        result_text = f"Unknown tool: {name}"
    else:
        kwargs = {"query": data} if isinstance(data, str) else data
        try:
            if not isinstance(kwargs, dict):
                raise TypeError(
                    f"Action Input must be a JSON object, got {type(kwargs).__name__}"
                )
            result = tool_fn(**kwargs)
            result_text = json.dumps(result, ensure_ascii=False, indent=2)
        except Exception as e:
            # Report failures to the LLM as an observation rather than crashing the graph.
            result_text = json.dumps({"error": str(e), "trace": traceback.format_exc()})

    # ToolMessage requires tool_call_id. Text-based ReAct has no real call ids,
    # so the tool name is used as a stand-in.
    tool_msg = ToolMessage(
        content=result_text,
        tool_call_id=name
    )

    return {"messages": [tool_msg], "action_name": None, "action_input": None}


# -----------------------------------------------------------
# Extract Final Answer
# -----------------------------------------------------------
# def extract_final_answer(state: AgentState) -> Optional[str]:
#     for msg in reversed(state["messages"]):
#         if isinstance(msg, AIMessage) and "Final Answer" in msg.content:
#             m = re.search(r"Action Input:\s*(.+)$", msg.content, re.S)
#             if m:
#                 return m.group(1).strip()
#             return msg.content
#     return None


def extract_final_answer(state):
    """Return the text of the most recent ReAct "Final Answer", or None.

    Args:
        state: Graph result mapping. ``state["messages"]`` is normally the flat
            list produced by the graph. A nested ``{"messages": [...]}``
            wrapper is also accepted.

    Returns:
        For the latest ``AIMessage`` whose string content mentions
        "Final Answer":
        - the text after ``Action: Final Answer`` / ``Action Input:`` when that
          ReAct pattern is present;
        - otherwise the text after "Final Answer", with an optional colon
          dropped.
        Returns ``None`` when no such message exists.
    """
    messages = state["messages"]
    # The graph returns a flat list; indexing it with "messages" would raise
    # TypeError. Only unwrap when a nested dict is actually passed.
    if isinstance(messages, dict):
        messages = messages.get("messages", [])

    for msg in reversed(messages):
        content = msg.content if isinstance(msg, AIMessage) else None
        if not isinstance(content, str) or "Final Answer" not in content:
            continue
        # Prefer the ReAct payload so the "Action Input:" label is not returned.
        m = re.search(
            r"Action:\s*Final Answer\s*\n\s*Action Input:\s*(.*)", content, re.S
        )
        if m is None:
            m = re.search(r"Final Answer:?\s*(.*)", content, re.S)
        return m.group(1).strip()

    return None


# -----------------------------------------------------------
# Build LangGraph
# -----------------------------------------------------------
def build_graph():
    """Assemble the (uncompiled) ReAct ``StateGraph``.

    Topology: START -> react_llm -> router, then either END or
    tool_executor -> react_llm again. The router goes to END whenever
    ``parse_and_decide`` left ``action_name`` as ``None``.
    """
    def _llm_step(s):
        # Bind the single available tool's name and description into the LLM node.
        return react_llm_node(
            s,
            tool_name="tavily_search",
            tool_description="Searches the internet and returns relevant info via Tavily.",
        )

    def route_fn(state: AgentState):
        if state.get("action_name") is None:
            return "end"
        return "tool"

    workflow = StateGraph(AgentState)

    workflow.add_node("react_llm", _llm_step)
    workflow.add_node("router", parse_and_decide)
    workflow.add_node("tool_executor", tool_executor_node)

    workflow.add_edge(START, "react_llm")
    workflow.add_edge("react_llm", "router")
    workflow.add_conditional_edges(
        "router",
        route_fn,
        {"end": END, "tool": "tool_executor"},
    )
    workflow.add_edge("tool_executor", "react_llm")

    return workflow


# -----------------------------------------------------------
# MAIN
# -----------------------------------------------------------
def main():
    """Interactive entry point: run the ReAct agent, then write a transcript.

    Prompts for a query and runs the compiled graph. It then prints the
    extracted final answer, followed by a kids' video transcript generated
    from that answer. If no final answer is found, every message in the
    result is dumped instead. An empty query exits early without calling any
    model.
    """
    app = build_graph().compile()

    user_query = input("Enter your query/topic: ").strip()
    if not user_query:
        # Avoid spending an LLM call (and a possible tool call) on empty input.
        print("No query entered; nothing to do.")
        return

    state: AgentState = {
        "messages": [HumanMessage(content=user_query)],
        "action_name": None,
        "action_input": None,
        "raw_output": None,
    }

    result = app.invoke(state)

    final = extract_final_answer(result)
    if not final:
        print("\n==== No Final Answer Found ====\n")
        for m in result["messages"]:
            print(type(m).__name__, ":", m.content)
        return

    print("\n==== Final Answer ====\n")
    print(final)

    print("\n==== Generating Transcript ====\n")
    # The researched final answer, not the raw user query, is used as the topic.
    transcript = generate_transcript(final)
    print(transcript)


# Start the interactive agent only when run as a script, not when imported.
if __name__ == "__main__":
    main()