In [8]:
from dotenv import load_dotenv

_ = load_dotenv()

In [9]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage

memory = SqliteSaver.from_conn_string(":memory:")

In [10]:
class AgentState(TypedDict):
    task: str
    plan: str
    draft: str
    critique: str
    content: List[str]
    revision_number: int
    max_revisions: int

In [11]:
# from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
# model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
model = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0) 

In [12]:
PLAN_PROMPT = """You are an expert composer tasked with writing a high level outline of an document. \
Write such an outline for the user provided topic. Give an outline of the document along with any relevant notes \
or instructions for the sections."""

WRITER_PROMPT = """You are an document assistant tasked with writing excellent 5-paragraph documents.\
Generate the best document possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed: 

------

{content}"""

REFLECTION_PROMPT = """You are a teacher grading an document submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""

RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following document. Generate a list of search queries that will gather \
any relevant information. Only create 3 queries max."""

RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only create 3 queries max."""

In [13]:
from langchain_core.pydantic_v1 import BaseModel

class Queries(BaseModel):
    queries: List[str]


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


In [15]:
from tavily import TavilyClient
import os
tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])


In [16]:
def plan_node(state: AgentState):
    messages = [
        SystemMessage(content=PLAN_PROMPT), 
        HumanMessage(content=state['task'])
    ]
    response = model.invoke(messages)
    return {"plan": response.content}

In [17]:
def research_plan_node(state: AgentState):
    queries = model.with_structured_result(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task'])
    ])
    content = state['content'] or []
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}

In [18]:
def generation_node(state: AgentState):
    content = "\n\n".join(state['content'] or [])
    user_message = HumanMessage(
        content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
    messages = [
        SystemMessage(
            content=WRITER_PROMPT.format(content=content)
        ),
        user_message
        ]
    response = model.invoke(messages)
    return {
        "draft": response.content, 
        "revision_number": state.get("revision_number", 1) + 1
    }


In [19]:
def reflection_node(state: AgentState):
    messages = [
        SystemMessage(content=REFLECTION_PROMPT), 
        HumanMessage(content=state['draft'])
    ]
    response = model.invoke(messages)
    return {"critique": response.content}

In [20]:
def research_critique_node(state: AgentState):
    queries = model.with_structured_result(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state['critique'])
    ])
    content = state['content'] or []
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}

In [21]:
def should_continue(state):
    if state["revision_number"] > state["max_revisions"]:
        return END
    return "reflect"

In [26]:
Wentity = StateGraph(AgentState)

Wentity.add_node("planner", plan_node)
Wentity.add_node("create", generation_node)
Wentity.add_node("reflect", reflection_node)
Wentity.add_node("research_plan", research_plan_node)
Wentity.add_node("research_critique", research_critique_node)

Wentity.set_entry_point("planner")

Wentity.add_conditional_edges(
    "create", 
    should_continue, 
    {END: END, "reflect": "reflect"}
)

Wentity.add_edge("planner", "research_plan")
Wentity.add_edge("research_plan", "create")

Wentity.add_edge("reflect", "research_critique")
Wentity.add_edge("research_critique", "create")

<langgraph.graph.state.StateGraph at 0x127dfa26a50>

In [None]:
from langgraph.checkpoint.sqlite import SqliteSaver
with SqliteSaver.from_conn_string(":memory:") as checkpointer:
    graph = Wentity.compile(checkpointer=checkpointer)
    thread = {"configurable": {"thread_id": "1"}}
    for s in graph.stream({
        'task': "Hyper around Generative AI",
        "max_revisions": 2,
        "revision_number": 1,
    }, thread):
        print(s)
        

In [None]:
thread = {"configurable": {"thread_id": "1"}}
for s in graph.stream({
    'task': "what is the difference between langchain and langsmith",
    "max_revisions": 2,
    "revision_number": 1,
}, thread):
    print(s)