In [None]:
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage

from typing import TypedDict, List, Annotated, Literal, Union
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
import operator

from langgraph.types import Command, Send
from langgraph.graph import StateGraph, END, START
from langgraph.checkpoint.memory import MemorySaver

from prompts import *

from IPython.display import Image, display

from pydantic import BaseModel, Field
from typing import List, Literal, Dict, Any
from enum import Enum

import uuid

from tavily import TavilyClient

from dotenv import load_dotenv

load_dotenv()

True

In [3]:
def init_llm(
        provider: Literal["openai", "anthropic", "google", "ollama"],
        model: str,
        temperature: float = 0.5,
        **kwargs,
):
    """Create a LangChain chat model for the requested provider.

    Args:
        provider: Backend name: "openai", "anthropic", "google" or "ollama".
        model: Provider-specific model name (e.g. "gpt-4o-mini").
        temperature: Sampling temperature passed to the model constructor.
        **kwargs: Extra keyword arguments forwarded unchanged to the provider's
            chat-model class (provider-specific options).

    Returns:
        A ``ChatOpenAI``, ``ChatAnthropic``, ``ChatGoogleGenerativeAI`` or
        ``ChatOllama`` instance.

    Raises:
        ValueError: If ``provider`` is not one of the supported names, so a typo
            fails here instead of surfacing later as an error on ``None``.
    """
    if provider == "openai":
        return ChatOpenAI(model=model, temperature=temperature, **kwargs)
    if provider == "anthropic":
        return ChatAnthropic(model=model, temperature=temperature, **kwargs)
    if provider == "google":
        return ChatGoogleGenerativeAI(model=model, temperature=temperature, **kwargs)
    if provider == "ollama":
        return ChatOllama(model=model, temperature=temperature, **kwargs)
    raise ValueError(
        f"Unsupported provider {provider!r}; expected one of "
        "'openai', 'anthropic', 'google', 'ollama'."
    )

In [4]:
# Shared chat model referenced by every LLM-backed node in this notebook; change the
# provider/model here to switch backends.
llm = init_llm("openai", "gpt-4o-mini", temperature=0.5)

In [5]:
# Allowed data types for a column of the generated dataset schema.
class FieldType(str, Enum):
    string = "string"
    number = "number"
    array = "array"
    boolean = "boolean"

# One column of the dataset schema proposed by the schema-generator LLM.
# NOTE(review): only `key` and `description` are used when building the row-generation
# prompt (process_datagen_prompt); `type` is not enforced on the generated rows.
class SchemaField(BaseModel):
    key: str = Field(..., description="The unique identifier for the field")
    type: FieldType = Field(..., description="The data type of the field")
    description: str = Field(..., description="Some descriptive information for the field")

# Tool schema bound to the schema-generator LLM; the model's tool-call arguments are
# validated into this model. Documented with comments rather than a docstring because
# pydantic puts a class docstring into the JSON schema that is sent to the LLM.
class DatasetSchema(BaseModel):
    generated_schema: list[SchemaField]

# Wrapper used to validate the JSON rows generated for one section. Each row is only
# checked to be a JSON object; its keys are not checked against the schema.
class DatasetRecords(BaseModel):
    dataset: List[Dict[str, Any]]

In [6]:
# One top-level section of the planned report (produced via the section formatter).
# These models are documented with comments, not docstrings, so the JSON schemas that
# structured output sends to the LLM stay unchanged.
class Section(BaseModel):
    section_name: str = Field(..., description="The name of this section of the report without its number")
    sub_sections: List[str] = Field(..., description="Comprehensive descriptions of sub-sections, each combining the sub-section title and its bullet points into a fluid, natural-language description")

# Structured-output wrapper for the section formatter.
class Sections(BaseModel):
    sections: List[Section] = Field(..., description="A list of sections")

# A single web-search query.
class Query(BaseModel):
    query: str = Field(..., description="A search query")

# Structured-output wrapper for the query generator.
class Queries(BaseModel):
    queries: List[Query] = Field(..., description="A list of search queries")

# Search hits for one query (built in tavily_search_node).
# NOTE(review): despite the field name, raw_content holds each hit's Tavily "content"
# value, not its raw page text.
class SearchResult(BaseModel):
    query: Query = Field(..., description="The search query that was used to retrieve the raw content")
    raw_content: list[str] = Field(..., description="The raw content retrieved from the search")

# Reflection verdict: boolean True when the accumulated content is sufficient,
# otherwise a string describing what is missing or incorrect.
class Feedback(BaseModel):
    feedback: Union[str, bool] = Field(..., description="Feedback on the report structure. If the content is good for the section, return True (boolean), otherwise return a string of feedback on what is missing or incorrect.")

# Output schema of the per-section research subgraph: only this key is returned to the
# parent graph, where AgentState's operator.add reducer concatenates it across sections.
class SectionOutput(BaseModel):
    final_section_dataset: List[Dict[str, Any]] = Field(..., description="The final section dataset")

In [7]:
# Parent-graph state. Keys wrapped in Annotated[..., operator.add] are merged by
# concatenating each node's update onto the current list; other keys are overwritten.
# NOTE(review): TypedDict does not support default values (PEP 589). The "= ..."
# assignments here and in ResearchState only create class attributes; a state dict is
# not pre-filled with them, and nodes use their own fallbacks via state.get(...)
# instead (e.g. reflection_feedback_node falls back to 0 for reflection_count, not 1).
# Whether LangGraph reads these class attributes is unverified — consider removing them.
class AgentState(TypedDict):
    topic: str
    outline: str
    messages: Annotated[List[BaseMessage], operator.add]  # chat history for the schema/planner prompts; updates must be lists
    report_structure: str  # approved planner text, saved by human_feedback_node on "continue"
    sections: List[Section]  # set by section_formatter_node
    final_section_dataset: Annotated[List[Dict[str, Any]], operator.add] = []  # rows from every research_agent run, concatenated
    final_dataset: List[Dict[str, Any]]  # set by final_dataset_aggregator_node
    schema: DatasetSchema  # set by schema_generator_node

# State of the per-section research subgraph. topic, section, schema and
# report_structure are seeded by the Send payload from section_formatter_node; the
# remaining keys are filled in by the subgraph's own nodes.
class ResearchState(TypedDict):
    topic: str
    report_structure: str
    section: Section
    knowledge: str  # reply of the section-knowledge LLM (no web search involved)
    reflection_feedback: Feedback = Feedback(feedback="")  # NOTE(review): nodes store the bare Feedback.feedback value (str or bool) here, not a Feedback object
    generated_queries: List[Query] = []  # latest batch only; overwritten each round
    searched_queries: Annotated[List[Query], operator.add] = []  # every query issued so far, across reflection rounds
    search_results: Annotated[List[SearchResult], operator.add] = []  # accumulated across reflection rounds
    accumulated_content: str = ""  # reply of the result-accumulator LLM over search_results
    reflection_count: int = 1  # extra search rounds already triggered by reflection
    final_section_content: List[str] = []  # NOTE(review): final_section_formatter_node stores a single string here
    schema: DatasetSchema
    final_section_dataset: List[Dict[str, Any]] = []  # rows for this section; the only key returned via SectionOutput
    error: str  # failure reason set by final_section_dataset_generator_node

In [None]:
# Schema-generator prompt: system instructions (SCHEMA_GENERATION_PROMPT, presumably
# supplied by `from prompts import *`), the user's topic and outline, then the running
# conversation (earlier schema drafts and human feedback) from state["messages"].
# NOTE: the triple-quoted human template keeps its leading newline and indentation.
dataset_schema_generator_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(SCHEMA_GENERATION_PROMPT),
    HumanMessagePromptTemplate.from_template(
        template="""
        Topic: {topic}
        Outline: {outline}
        """
    ),
    MessagesPlaceholder(variable_name="messages")
])

# tool_choice="required" forces the model to answer with a DatasetSchema tool call,
# whose arguments schema_generator_node validates.
llm_with_schema_tool = llm.bind_tools(tools=[DatasetSchema], tool_choice="required")
schema_generator_llm = dataset_schema_generator_system_prompt | llm_with_schema_tool

def schema_generator_node(state: AgentState, config: RunnableConfig):
    """Ask the LLM for a dataset schema and record it in state.

    The bound model must call the ``DatasetSchema`` tool (tool_choice="required");
    the tool arguments are validated into a ``DatasetSchema``. The proposal is also
    appended to ``messages`` as an AI turn, so when the human asks for changes the
    next invocation sees both the previous draft and the feedback.

    Returns:
        ``{"schema": DatasetSchema, "messages": [AIMessage]}``.

    Raises:
        ValueError: If the model reply contains no tool call.
        pydantic.ValidationError: If the tool arguments do not match DatasetSchema.
    """
    from langchain_core.messages import AIMessage

    result = schema_generator_llm.invoke(state)
    if not result.tool_calls:
        raise ValueError("Schema generator returned no DatasetSchema tool call")
    suggested_schema = DatasetSchema.model_validate(result.tool_calls[0]["args"])

    # `messages` has an operator.add reducer over a list, so the update must itself be
    # a list of messages (a bare string cannot be concatenated onto the list). The raw
    # tool-call reply is not stored: replaying an assistant tool call without a matching
    # tool result is typically rejected by OpenAI-style chat APIs.
    schema_summary = AIMessage(content=suggested_schema.model_dump_json(indent=2))
    return {"schema": suggested_schema, "messages": [schema_summary]}

In [10]:
def human_feedback_on_schema_node(state: AgentState, config: RunnableConfig) -> Command[Literal["report_structure_planner", "schema_generator"]]:
    """Ask the user to approve or critique the generated dataset schema.

    Typing ``continue`` (case-insensitive, surrounding whitespace ignored) accepts the
    schema and moves on to report planning. Any other input is recorded as a human
    message and routed back to the schema generator for a revision.

    Note: this blocks on ``input()``, so the graph must be driven interactively.
    """
    human_message = input("Please provide feedback on the dataset schema (type 'continue' to continue): ")
    feedback_message = HumanMessage(content=human_message)

    if human_message.strip().lower() == "continue":
        return Command(
            goto="report_structure_planner",
            update={"messages": [feedback_message], "schema": state.get("schema")},
        )
    return Command(
        goto="schema_generator",
        update={"messages": [feedback_message]},
    )

In [12]:
# Report-structure planner prompt: system instructions
# (REPORT_STRUCTURE_PLANNER_SYSTEM_PROMPT_TEMPLATE, presumably from `prompts`), the
# topic and outline, then the whole conversation so far (schema turns, previous plans and
# human feedback) from state["messages"].
# NOTE: the triple-quoted human template keeps its leading newline and indentation.
report_structure_planner_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(REPORT_STRUCTURE_PLANNER_SYSTEM_PROMPT_TEMPLATE),
    HumanMessagePromptTemplate.from_template(
        template="""
        Topic: {topic}
        Outline: {outline}
        """
    ),
    MessagesPlaceholder(variable_name="messages")
])

# Plain chat chain: the reply is free text that the human then reviews.
report_structure_planner_llm = report_structure_planner_system_prompt | llm

def report_structure_planner_node(state: AgentState, config: RunnableConfig):
    """Draft (or revise) the report structure from the topic, outline and chat history.

    The model reply is appended to ``messages`` so the human-feedback node can show
    it and later revisions receive it as context.
    """
    planner_reply = report_structure_planner_llm.invoke(state)
    update = {"messages": [planner_reply]}
    return update

In [13]:
def human_feedback_node(state: AgentState, config: RunnableConfig) -> Command[Literal["section_formatter", "report_structure_planner"]]:
    """Ask the user to approve or critique the proposed report structure.

    The structure under review is the content of the last message (the planner's
    reply). Typing ``continue`` (case-insensitive, surrounding whitespace ignored)
    saves it as ``report_structure`` and proceeds to section formatting. Anything else
    is recorded as a human message and routed back to the planner for revision.

    Note: this blocks on ``input()``, so the graph must be driven interactively.
    """
    human_message = input("Please provide feedback on the report structure (type 'continue' to continue): ")
    feedback_message = HumanMessage(content=human_message)

    if human_message.strip().lower() != "continue":
        return Command(
            goto="report_structure_planner",
            update={"messages": [feedback_message]},
        )

    report_structure = state.get("messages")[-1].content
    return Command(
        goto="section_formatter",
        update={"messages": [feedback_message], "report_structure": report_structure},
    )

In [15]:

# Section formatter: turns the approved free-text report structure
# (state["report_structure"]) into a `Sections` object via structured output.
section_formatter_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(SECTION_FORMATTER_SYSTEM_PROMPT_TEMPLATE),
    HumanMessagePromptTemplate.from_template(template="{report_structure}"),
])

section_formatter_llm = section_formatter_system_prompt | llm.with_structured_output(Sections)

def section_formatter_node(state: AgentState, config: RunnableConfig) -> Command[Literal["research_agent"]]:
    """Split the approved report structure into sections and fan out research.

    The structured-output LLM converts the report structure into a ``Sections``
    object. The sections are stored in state, and one ``Send`` to the
    "research_agent" subgraph is issued per section. Each Send carries the topic,
    that section, the dataset schema and the full report structure.

    NOTE(review): if the model returns no sections, no Send is issued and
    final_dataset_aggregator (reachable only from research_agent) never runs.
    """
    formatted = section_formatter_llm.invoke(state)

    shared_inputs = {
        "topic": state.get("topic"),
        "schema": state.get("schema"),
        "report_structure": state.get("report_structure"),
    }
    fan_out = []
    for section in formatted.sections:
        fan_out.append(Send("research_agent", {**shared_inputs, "section": section}))

    return Command(goto=fan_out, update={"sections": formatted.sections})

In [17]:
# Section-knowledge prompt: asks the model what it already knows about one section.
# The Section model is interpolated into "{section}" using its default string form.
section_knowledge_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(SECTION_KNOWLEDGE_SYSTEM_PROMPT_TEMPLATE),
    HumanMessagePromptTemplate.from_template(template="{section}"),
])

section_knowledge_llm = section_knowledge_system_prompt | llm

def section_knowledge_node(state: ResearchState, config: RunnableConfig):
    """Record the model's own (search-free) knowledge about the current section."""
    knowledge_message = section_knowledge_llm.invoke(state)
    knowledge_text = knowledge_message.content
    return {"knowledge": knowledge_text}

In [None]:

def query_generator_node(state: ResearchState, config: RunnableConfig):
    """Generate the next batch of web-search queries for the current section.

    The prompt receives the section, the queries already searched and the latest
    reflection feedback (both empty on the first pass). It also receives every
    ``config["configurable"]`` value (e.g. ``max_queries``), presumably referenced
    by the system prompt template.

    Returns:
        ``generated_queries`` (this batch only) and ``searched_queries`` (appended to
        the running list by the ResearchState reducer).
    """
    query_generator_system_prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(QUERY_GENERATOR_SYSTEM_PROMPT_TEMPLATE),
        HumanMessagePromptTemplate.from_template(template="Section: {section}\nPrevious Queries: {searched_queries}\nReflection Feedback: {reflection_feedback}"),
    ])
    query_generator_llm = query_generator_system_prompt | llm.with_structured_output(Queries)

    configurable = config.get("configurable") or {}

    # First-pass defaults come first so real state values win; building a new dict
    # avoids mutating the state mapping handed to this node.
    input_data = {
        "reflection_feedback": "",
        "searched_queries": [],
        **state,
        **configurable,  # includes max_queries, search_depth, etc.
    }

    # Forward the node's RunnableConfig (not the bare configurable dict) so the child
    # run inherits callbacks/tags and the configurable values under the right key.
    result = query_generator_llm.invoke(input_data, config)
    return {"generated_queries": result.queries, "searched_queries": result.queries}


In [20]:
tavily_client = TavilyClient()

def tavily_search_node(state: ResearchState, config: RunnableConfig):
    """Run every freshly generated query through Tavily and collect the hits' content.

    ``configurable["search_depth"]`` is passed as Tavily's ``max_results`` (hits per
    query). It is unrelated to Tavily's own ``search_depth`` parameter. When the key
    is not configured, ``max_results`` is omitted so Tavily's default applies.

    Returns:
        ``{"search_results": [SearchResult, ...]}`` with one entry per query. The
        ResearchState reducer appends these to earlier rounds' results.
    """
    configurable = config.get("configurable") or {}
    search_kwargs = {"include_raw_content": True}
    max_results = configurable.get("search_depth")
    if max_results is not None:
        search_kwargs["max_results"] = max_results

    search_results = []
    for query in state.get("generated_queries") or []:
        response = tavily_client.search(query=query.query, **search_kwargs)
        # NOTE(review): raw page text is requested (include_raw_content=True) but only
        # each hit's "content" field is kept; either use result["raw_content"] or stop
        # requesting it.
        snippets = [hit["content"] for hit in response["results"]]
        search_results.append(SearchResult(query=query, raw_content=snippets))
    return {"search_results": search_results}

In [None]:

def result_accumulator_node(state: ResearchState, config: RunnableConfig):
    """Condense all search results gathered so far into ``accumulated_content``.

    The prompt is rebuilt on every call; its human turn is the string form of
    ``state["search_results"]`` (all rounds, thanks to the reducer).
    """
    prompt_messages = [
        SystemMessagePromptTemplate.from_template(RESULT_ACCUMULATOR_SYSTEM_PROMPT_TEMPLATE),
        HumanMessagePromptTemplate.from_template(template="{search_results}"),
    ]
    accumulator_chain = ChatPromptTemplate.from_messages(prompt_messages) | llm
    summary = accumulator_chain.invoke(state)
    return {"accumulated_content": summary.content}

In [24]:
# Reflection prompt: the model judges whether the accumulated search content covers the
# section and answers with a Feedback object (True, or a critique string).
reflection_feedback_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(REFLECTION_FEEDBACK_SYSTEM_PROMPT_TEMPLATE),
    HumanMessagePromptTemplate.from_template(template="Section: {section}\nAccumulated Content: {accumulated_content}"),
])

reflection_feedback_llm = reflection_feedback_system_prompt | llm.with_structured_output(Feedback)

def reflection_feedback_node(state: ResearchState, config: RunnableConfig) -> Command[Literal["final_section_formatter", "query_generator"]]:
    """Judge whether the accumulated research is sufficient and route accordingly.

    The reflection LLM returns ``Feedback.feedback``. This is ``True`` (or the string
    "true") when the section is covered, otherwise a critique string.

    Routing:
      * content sufficient, or reflection budget used up
        (``reflection_count >= configurable["num_reflections"]``)
        -> "final_section_formatter";
      * otherwise -> "query_generator", with the critique stored in
        ``reflection_feedback`` and ``reflection_count`` incremented.

    ``reflection_count`` starts at 0, since it is not part of the Send payload.
    ``num_reflections`` defaults to 1 when not configured. Together these bound the
    number of extra search rounds per section.
    """
    configurable = config.get("configurable") or {}
    max_reflections = configurable.get("num_reflections")
    if max_reflections is None:
        max_reflections = 1
    reflection_count = state.get("reflection_count", 0)

    feedback = reflection_feedback_llm.invoke(state).feedback
    # Structured output may yield a real bool or a string; a bool has no .lower().
    is_sufficient = feedback is True or (
        isinstance(feedback, str) and feedback.strip().lower() == "true"
    )

    # Finish when the content is good enough OR the retry budget is spent; only keep
    # researching while budget remains (prevents an unbounded query loop).
    if is_sufficient or reflection_count >= max_reflections:
        return Command(
            update={"reflection_feedback": feedback},
            goto="final_section_formatter"
        )
    return Command(
        update={"reflection_feedback": feedback, "reflection_count": reflection_count + 1},
        goto="query_generator"
    )

In [26]:
# Final section formatter: merges the model's internal knowledge with the accumulated
# search content into the section's final text.
final_section_formatter_system_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(FINAL_SECTION_FORMATTER_SYSTEM_PROMPT_TEMPLATE),
    HumanMessagePromptTemplate.from_template(template="Internal Knowledge: {knowledge}\nSearch Result content: {accumulated_content}"),
])

final_section_formatter_llm = final_section_formatter_system_prompt | llm

def final_section_formatter_node(state: ResearchState, config: RunnableConfig):
    """Write the section's final text from internal knowledge plus search findings.

    Stores the model's reply text under ``final_section_content``. This is a single
    string, although ResearchState annotates the key as ``List[str]``.
    """
    formatted_message = final_section_formatter_llm.invoke(state)
    section_text = formatted_message.content
    return {"final_section_content": section_text}

In [27]:
def process_datagen_prompt(fields: List[SchemaField], rows: int = 10) -> str:
    """Build the system prompt that asks the LLM for dataset rows.

    Args:
        fields: Schema columns. Each field's ``key`` becomes a key of the example
            JSON object and its ``description`` the example value; ``type`` is unused.
        rows: Number of dataset rows the model is asked to produce.

    Returns:
        The complete system-prompt string. It contains literal JSON braces, so it
        must not be used as a format/template string.
    """
    # Imported locally: the notebook's module-level `import json` lives in a later
    # cell, and this function must not depend on that cell having run.
    import json

    schema_instruction = {field.key: field.description for field in fields}

    field_string = f"""## Response Format
Always respond with a valid JSON array of objects:
[
{json.dumps(schema_instruction, indent=2)},
// Additional entries...
]
"""
    return f"""
You are an expert Question-Answer generation assistant who has the skills of a polymath. Your task is to analyze content provided by the user and generate a comprehensive set of questions with detailed answers based on that content.

## Core Instructions

1. When presented with content, carefully analyze it to identify key concepts, important details, practical applications, and potential challenges or edge cases.

2. Generate a diverse set of questions and answers that thoroughly cover the provided content. Your response must be in valid JSON format.

3. Format code properly within JSON strings, using appropriate escape characters for special characters.

4. Number of dataset rows must be {rows}

{field_string}
"""

In [28]:
import time
import json
from openai import RateLimitError, OpenAIError
from pydantic import ValidationError

def _strip_markdown_fence(text: str) -> str:
    """Return ``text`` without surrounding whitespace and an optional ```json / ``` fence."""
    cleaned = text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[len("```json"):]
    elif cleaned.startswith("```"):
        cleaned = cleaned[len("```"):]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-len("```")]
    return cleaned.strip()


def final_section_dataset_generator_node(state: ResearchState, config: RunnableConfig, max_retries: int = 3, base_wait: float = 2.0):
    """Turn one researched section into validated dataset rows.

    Builds a system prompt from the dataset schema (process_datagen_prompt) and asks
    ``llm`` for a JSON array. It then strips an optional Markdown code fence, parses
    the JSON, and validates it with DatasetRecords.

    Args:
        state: Subgraph state; reads ``schema``, ``report_structure`` and
            ``final_section_content``.
        config: Reads ``configurable["max_rows_from_each_section"]``. When it is
            absent, process_datagen_prompt's own default row count is used.
        max_retries: Attempts made for OpenAI API errors (rate limits etc.).
            Malformed or invalid model output is not retried.
        base_wait: Base delay in seconds for exponential backoff between attempts.

    Returns:
        ``{"final_section_dataset": rows}`` on success, otherwise
        ``{"final_section_dataset": [], "error": <reason>}``. Errors from the model
        call and from parsing are reported this way instead of raised, so one bad
        section does not abort the other sections.
    """
    schema = state.get("schema")
    if schema is None:
        return {"final_section_dataset": [], "error": "Missing schema"}

    configurable = config.get("configurable") or {}
    max_rows = configurable.get("max_rows_from_each_section")
    if max_rows is None:
        generation_prompt = process_datagen_prompt(schema.generated_schema)
    else:
        generation_prompt = process_datagen_prompt(schema.generated_schema, int(max_rows))

    # The generated prompt embeds literal JSON braces, so it is added as a fixed
    # SystemMessage rather than a template that would try to format them.
    final_section_dataset_generator_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=generation_prompt),
        HumanMessagePromptTemplate.from_template(template="Report Structure: {report_structure}\nSection Contents: {final_section_content}"),
    ])
    final_dataset_generator_llm = final_section_dataset_generator_prompt | llm

    for attempt in range(max_retries):
        retries_left = attempt < max_retries - 1
        try:
            result = final_dataset_generator_llm.invoke(state)
            # Strip outer whitespace before looking for the fence: a reply ending in
            # "```\n" would otherwise keep the closing fence and fail to parse.
            raw_text = _strip_markdown_fence(result.content)

            parsed_json = json.loads(raw_text)
            validated = DatasetRecords(dataset=parsed_json)

            return {"final_section_dataset": validated.dataset}

        except json.JSONDecodeError as e:
            print(f"[JSON Parse Error] {e}")
            return {"final_section_dataset": [], "error": "JSONDecodeError"}

        except ValidationError as e:
            print(f"[Pydantic Validation Error] {e}")
            return {"final_section_dataset": [], "error": "ValidationError"}

        except RateLimitError:
            # Only back off when another attempt will actually follow.
            if retries_left:
                wait_time = base_wait * (2 ** attempt)
                print(f"[Rate Limit] Retrying in {wait_time}s (Attempt {attempt + 1}/{max_retries})...")
                time.sleep(wait_time)
            else:
                print(f"[Rate Limit] Giving up after {max_retries} attempts")

        except OpenAIError as e:
            print(f"[OpenAI Error] {e}")
            if retries_left:
                time.sleep(base_wait * (2 ** attempt))

        except Exception as e:
            print(f"[Unexpected Error] {e}")
            return {"final_section_dataset": [], "error": str(e)}

    return {"final_section_dataset": [], "error": "Max retries exceeded"}

In [29]:
def final_dataset_aggregator_node(state: AgentState, config: RunnableConfig):
    """Collect the rows produced by every section into ``final_dataset``.

    ``final_section_dataset`` already holds the rows from all research_agent runs,
    because AgentState's operator.add reducer concatenates each run's list. This node
    therefore just copies it. A missing or None value yields an empty dataset.
    """
    section_rows = state.get("final_section_dataset") or []
    return {"final_dataset": list(section_rows)}

In [None]:
# Per-section research subgraph. It gathers the model's own knowledge, then loops
# query -> search -> accumulate -> reflect until reflection routes to the final
# formatter, and finally turns the section into dataset rows. Only SectionOutput's keys
# (final_section_dataset) are returned to the parent graph.
research_builder = StateGraph(ResearchState, output=SectionOutput)

research_builder.add_node("section_knowledge", section_knowledge_node)
research_builder.add_node("query_generator", query_generator_node)
research_builder.add_node("tavily_search", tavily_search_node)
research_builder.add_node("result_accumulator", result_accumulator_node)
research_builder.add_node("reflection", reflection_feedback_node)
research_builder.add_node("final_section_formatter", final_section_formatter_node)
research_builder.add_node("final_section_dataset_generator", final_section_dataset_generator_node)

# "reflection" has no static outgoing edge: reflection_feedback_node returns a Command
# whose goto picks "final_section_formatter" or "query_generator".
research_builder.add_edge(START, "section_knowledge")
research_builder.add_edge("section_knowledge", "query_generator")
research_builder.add_edge("query_generator", "tavily_search")
research_builder.add_edge("tavily_search", "result_accumulator")
research_builder.add_edge("result_accumulator", "reflection")
research_builder.add_edge("final_section_formatter", "final_section_dataset_generator")
research_builder.add_edge("final_section_dataset_generator", END)

# In-memory checkpointer for the parent graph; runs are keyed by the thread_id in the
# run config.
memory_saver = MemorySaver()

# Parent graph: schema -> human review -> report plan -> human review -> section
# fan-out (one Send per section to the compiled subgraph) -> aggregation. The two
# human-feedback nodes and section_formatter route via Command, so they have no static
# outgoing edges here.
builder = StateGraph(AgentState)

builder.add_node("schema_generator", schema_generator_node)
builder.add_node("human_feedback_on_schema", human_feedback_on_schema_node)
builder.add_node("report_structure_planner", report_structure_planner_node)
builder.add_node("human_feedback_report_structure", human_feedback_node)
builder.add_node("section_formatter", section_formatter_node)
builder.add_node("research_agent", research_builder.compile())
builder.add_node("final_dataset_aggregator", final_dataset_aggregator_node)

builder.set_entry_point("schema_generator")
builder.add_edge("schema_generator", "human_feedback_on_schema")
builder.add_edge("report_structure_planner", "human_feedback_report_structure")
builder.add_edge("research_agent", "final_dataset_aggregator")
builder.add_edge("final_dataset_aggregator", END)

In [None]:
graph = builder.compile(checkpointer=memory_saver)

# draw_mermaid_png() renders through a remote Mermaid service by default. If that
# fails (e.g. offline), print the Mermaid source instead so Run-All can continue to
# the cells that actually use `graph`.
graph_view = graph.get_graph(xray=1)
try:
    display(Image(graph_view.draw_mermaid_png()))
except Exception as exc:
    print(f"[Graph render skipped] {exc}")
    print(graph_view.draw_mermaid())

In [None]:
import json
from datetime import datetime
import os

TOPIC = "Support Vector Machines"
OUTLINE = "I want to have qna dataset on this topic A-Z so that the model I train would be able to answer everythin about support vector machines"

# Run configuration: thread_id keys the MemorySaver checkpoint; the other entries are
# read by the nodes from config["configurable"].
thread = {
    "configurable": {
        "thread_id": str(uuid.uuid4()),
        "max_queries": 2,
        "search_depth": 1,
        "num_reflections": 2,
        "max_rows_from_each_section": 5,
    }
}


def _show_stage(title, payload):
    """Print a stage banner, its payload, and the horizontal divider."""
    print(title)
    print(payload)
    print("\n", "=" * 100, "\n")


def _save_final_dataset(output_data):
    """Write the aggregator output to output_files/ as timestamped JSON; return the path."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"final_dataset_output_{timestamp}.json"
    output_dir = "output_files"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)
    return filepath


# Stream node-by-node updates; the human-feedback nodes prompt via input() mid-stream.
initial_state = {"topic": TOPIC, "outline": OUTLINE}
for event in graph.stream(initial_state, config=thread):
    if "schema_generator" in event:
        _show_stage("<<< SCHEMA GENERATOR >>>", event["schema_generator"]["schema"])
    elif "report_structure_planner" in event:
        _show_stage("<<< REPORT STRUCTURE PLANNER >>>", event["report_structure_planner"]["messages"][-1].content)
    elif "section_formatter" in event:
        _show_stage("<<< SECTION FORMATTING >>>", event["section_formatter"])
    elif "research_agent" in event:
        # check output of research_agent
        _show_stage("<<< RESEARCH AGENT >>>", event["research_agent"])
    elif "final_dataset_aggregator" in event:
        # check output of final_dataset_aggregator, then persist it
        output_data = event["final_dataset_aggregator"]
        _show_stage("<<< FINAL REPORT WRITER >>>", output_data)
        filepath = _save_final_dataset(output_data)
        print(f"Saved final dataset to: {filepath}")
    elif "human_feedback_on_schema" in event:
        _show_stage("<<< HUMAN FEEDBACK ON SCHEMA >>>", event["human_feedback_on_schema"]["messages"][-1].content)
    elif "human_feedback_report_structure" in event:
        _show_stage("<<< HUMAN FEEDBACK ON REPORT STRUCTURE >>>", event["human_feedback_report_structure"]["messages"][-1].content)