In [None]:
# Load API credentials into environment variables before any client below is created.
# NOTE(review): presumably sets OPENAI_API_KEY and TAVILY_API_KEY — confirm in utils.helpers.
from utils.helpers import set_api_keys_env
set_api_keys_env()

#### Essay Writer Design

1. Plan ===> Research Plan
2. Research Plan ===> Generate
3. Generate
    a. if the revision limit has been reached (`revision_number >= max_revision`): end
    b. otherwise:
        i. Generate ===> Reflect
        ii. Reflect ===> Research Critique
        iii. Research Critique ===> Generate (back to step 3)

In [None]:
import operator
import warnings
from typing import Annotated, List, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, StateGraph
from rich.console import Console
from rich.markdown import Markdown

In [None]:
class AgentState(TypedDict):
    """Shared state passed between the nodes of the essay-writer graph.

    Each node returns a partial dict containing only the keys it updates.
    """
    task: str # the user's essay request (the human input)
    plan: str # essay outline written by the planner node
    draft: str # latest essay draft written by the generate node
    critique: str # critique of the latest draft, written by the reflect node
    content: List[str] # research snippets collected from Tavily search results
    revision_number: int # starts at 1 (set by the caller); generate_node adds 1 after each draft
    max_revision: int # generation stops once revision_number >= max_revision (see should_continue)

In [None]:
from langchain_openai import ChatOpenAI

# One chat model shared by every node; temperature 0.0 minimises sampling randomness.
model = ChatOpenAI(temperature=0.0, model="gpt-4o")

In [None]:
# System prompt for the planner node: asks for an essay outline plus per-section notes.
# Continuation lines are deliberately NOT indented: inside a triple-quoted string any
# leading spaces after a backslash-newline become part of the prompt text.
PLAN_PROMPT = """You are an expert writer tasked with writing a high-level outline of an essay. \
Write such an outline for the user-provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""

In [None]:
# System prompt for research_plan_node: turns the task into (at most 3) search queries.
RESEARCH_PLAN_PROMPT = (
    "You are a researcher charged with providing information that can "
    "be used when writing the following essay. "
    "Generate a list of search queries that will gather any relevant information. "
    "Only generate 3 queries max."
)

In [None]:
# System prompt for generate_node. `{content}` is filled via str.format with the
# research snippets joined by blank lines; any other literal braces would need doubling.
WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays. \
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:

------
{content}"""

In [None]:
# System prompt for reflection_node: the model acts as a teacher critiquing the draft.
REFLECTION_PROMPT = (
    "You are a teacher grading an essay submission. "
    "Generate critique and recommendations for the user's submission. "
    "Provide detailed recommendations, including requests for length, depth, style, etc."
)

In [None]:
# System prompt for research_critique_node: turns the critique into (at most 3) search queries.
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""

In [None]:
from langchain_core.pydantic_v1 import BaseModel, Field

# NOTE(review): `langchain_core.pydantic_v1` is a compatibility shim deprecated in recent
# langchain-core releases; consider `from pydantic import BaseModel, Field` when upgrading.
class Queries(BaseModel):
    """Web search queries to run for the essay (at most three)."""

    # Presumably the docstring and this description are carried into the schema that
    # with_structured_output hands to the model, guiding what it generates — confirm.
    queries: List[str] = Field(..., description="Search queries to send to the web-search tool.")

In [None]:
from tavily import TavilyClient

# Shared web-search client used by research_plan_node and research_critique_node.
# NOTE(review): no api_key is passed here, so presumably the client reads it from the
# environment (populated by set_api_keys_env at the top) — confirm.
tavily = TavilyClient()

#### Nodes 

In [None]:
def plan_node(state: AgentState):
    """Planner node: ask the model for an essay outline for ``state["task"]``.

    Returns the partial state update ``{"plan": <outline text>}``.
    """
    prompt = [SystemMessage(content=PLAN_PROMPT)]
    prompt.append(HumanMessage(content=state["task"]))
    outline = model.invoke(prompt).content
    return {"plan": outline}

In [None]:
def research_plan_node(state: AgentState):
    """Research-plan node: turn the task into search queries and collect results.

    Asks the model (structured output ``Queries``) for search queries derived from
    ``state["task"]`` (the prompt asks for at most three), runs each through Tavily,
    and appends every result's ``content`` to the snippets already gathered.

    Returns:
        ``{"content": [...]}`` holding a *new* list; the list stored in ``state``
        is left untouched.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state["task"]),
    ])
    # Copy instead of aliasing: appending to the list held in `state` would mutate the
    # caller's state object in place. `or []` also covers a key present with value None.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(q, max_results=3)
        content.extend(result['content'] for result in response.get('results', []))
    return {"content": content}

In [None]:
def generate_node(state: AgentState):
    """Generate node: write the essay draft, or revise it when a critique exists.

    The system prompt embeds all research snippets from ``state["content"]``; the
    first user turn carries the task and the plan. When both a previous draft and
    a critique are present (i.e. after a reflect -> research_critique pass), the
    draft is replayed as the assistant's turn and the critique follows as a user
    turn. Without this, the model never sees its earlier draft or the critique,
    so it cannot revise as WRITER_PROMPT instructs.

    Returns:
        ``{"draft": <essay text>, "revision_number": <previous value, default 1> + 1}``.
    """
    content = "\n\n".join(state.get('content') or [])
    messages = [
        SystemMessage(content=WRITER_PROMPT.format(content=content)),
        HumanMessage(content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"),
    ]
    previous_draft = state.get('draft')
    critique = state.get('critique')
    if previous_draft and critique:
        messages.append(AIMessage(content=previous_draft))
        messages.append(
            HumanMessage(content=f"Here is the critique of your previous draft:\n\n{critique}")
        )
    response = model.invoke(messages)
    return {
        "draft": response.content,
        "revision_number": state.get("revision_number", 1) + 1,
    }

In [None]:
def reflection_node(state: AgentState):
    """Reflect node: have the model critique the current draft.

    Returns the partial state update ``{"critique": <critique text>}``.
    """
    reply = model.invoke(
        [SystemMessage(content=REFLECTION_PROMPT), HumanMessage(content=state["draft"])]
    )
    critique_text = reply.content
    return {"critique": critique_text}

In [None]:
def research_critique_node(state: AgentState):
    """Research-critique node: search for material that addresses the critique.

    Asks the model (structured output ``Queries``) for search queries derived from
    ``state["critique"]`` (the prompt asks for at most three), runs each through
    Tavily, and appends every result's ``content`` to the snippets already gathered.

    Returns:
        ``{"content": [...]}`` holding a *new* list; the list stored in ``state``
        is left untouched.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state["critique"]),
    ])
    # Copy instead of aliasing: appending to the list held in `state` would mutate the
    # caller's state object in place. `or []` also covers a key present with value None.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(q, max_results=3)
        content.extend(result['content'] for result in response.get('results', []))
    return {"content": content}

In [None]:
def should_continue(state):
    """Route after each draft: stop once the revision budget is spent.

    Returns ``END`` when ``state['revision_number'] >= state['max_revision']``,
    otherwise ``"reflect"`` so the draft gets critiqued.
    """
    budget_spent = state['revision_number'] >= state['max_revision']
    return END if budget_spent else "reflect"

In [None]:
def create_agent():
    """Build and compile the essay-writer graph.

    Flow: planner -> research_plan -> generate. After every generate,
    ``should_continue`` either ends the run or routes to
    reflect -> research_critique -> generate again.

    Returns:
        The graph compiled with a fresh ``InMemorySaver`` checkpointer.
    """
    graph = StateGraph(AgentState)

    # Register the nodes (same order as the design outline).
    for node_name, node_fn in (
        ("planner", plan_node),
        ("research_plan", research_plan_node),
        ("generate", generate_node),
        ("reflect", reflection_node),
        ("research_critique", research_critique_node),
    ):
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("planner")

    # After each draft: end at the revision limit, otherwise critique it.
    graph.add_conditional_edges(
        "generate",
        should_continue,
        {END: END, "reflect": "reflect"},
    )

    # Unconditional edges.
    for src, dst in (
        ("planner", "research_plan"),
        ("research_plan", "generate"),
        ("reflect", "research_critique"),
        ("research_critique", "generate"),
    ):
        graph.add_edge(src, dst)

    return graph.compile(checkpointer=InMemorySaver())
    

In [None]:
# Build the graph once as a quick check that it compiles; it is this cell's last
# expression, so Jupyter displays the compiled graph object.
create_agent()

In [None]:
# For each node, the state key it writes, i.e. the field of that node's streamed
# update that write_essay_on_topic displays.
mapping1 = dict(
    planner='plan',
    research_plan='content',
    generate='draft',
    reflect='critique',
    research_critique='content',
)

In [None]:
def write_essay_on_topic(topic: str, intermediate: bool = False, max_revision: int = 3):
    """Run the essay-writer graph on ``topic`` and render every step to the console.

    Each streamed node update is printed as Markdown: a header with the node name,
    a rule, then the state field that node writes (looked up in ``mapping1``).
    List fields (research snippets) are printed as a 1-based numbered list.

    Args:
        topic: The essay request; stored as ``task`` in the graph state.
        intermediate: Currently has no effect; kept so existing calls still work.
            NOTE(review): decide whether it should suppress the per-node output.
        max_revision: Passed to the graph as ``max_revision``; generation stops once
            ``revision_number`` reaches it. Defaults to the previously hard-coded 3.

    Returns:
        The last draft emitted by the ``generate`` node, or None if none was emitted.
    """
    agent = create_agent()
    # create_agent() builds a fresh InMemorySaver on every call, so reusing a fixed
    # thread id cannot collide with an earlier run.
    config = {"configurable": {"thread_id": "1"}}
    console = Console()
    console.print(Markdown(f'# Essay Topic: {topic}'))

    inputs = {"task": topic, "max_revision": max_revision, "revision_number": 1}
    final_draft = None
    for step in agent.stream(inputs, config):
        for node_name, update in step.items():
            console.print(Markdown(f'## {node_name.title()}'))
            console.print(Markdown('---'))
            # Unmapped nodes (or empty updates) get a header but no body instead of
            # raising KeyError on a `None` key.
            state_key = mapping1.get(node_name)
            node_content = (update or {}).get(state_key) if state_key else None
            if isinstance(node_content, str):
                if node_name == 'generate':
                    final_draft = node_content
                console.print(Markdown(node_content))
            elif isinstance(node_content, list):
                for number, snippet in enumerate(node_content, start=1):
                    console.print(Markdown(f"{number}. {snippet}"))
    return final_draft
    

In [None]:
# Ignore every warning category, but only for the duration of this run:
# catch_warnings restores the previous filters on exit.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    write_essay_on_topic(
        'Write an Essay about HSBC, its origin, business and the current scenario'
    )
    
    