In [121]:
from __future__ import annotations

import os
from operator import add
from typing_extensions import TypedDict, Annotated
import yaml
from dotenv import load_dotenv
from langchain_core.runnables import RunnableSerializable, RunnableLambda, RunnableParallel, RunnablePassthrough, RunnablePick
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.graph.message import AnyMessage, add_messages
import openai
from pprint import pprint
from pydantic import BaseModel, Field

load_dotenv()

openai.api_key = os.getenv("OPENAI_API_KEY")

In [126]:
class Task(BaseModel):
    """`Task` object repesents a general but specific subproblem, decomoposed from the orginal `input_problem` of the user"""
    description: str = Field(..., description="Description with essential details of the task")
    

class SolvedTask(Task):
    """`SolvedTask` object represents a task that has been solved by the system"""
    solution: str = Field(..., description="The solution of the task")


class SchemaAgentInput(BaseModel):
    input_problem: str = Field(..., description="The original problem statement given by the user")
    task: Task = Field(..., description="The current task to be solved")
    task_history: list[SolvedTask] = Field([], description="List of tasks that have been solved so far")


class ThoughtProcess(SchemaAgentInput):
    """A `ThoughtProcess` object represents the thought process of the system"""

    steps: list[str] = []
    thought: Thought | None = None


class Thought(BaseModel):
    """A `Thought` object represents a distinct thought within the cognition of the system"""

    thought: AIMessage
    evaluation: float | None = Field(
        default=None,
        description="The evaluation of the thought. It can be a number between 0 and 1.0 being 0 the worst and 1.0 the best."
    )
    context: list[AnyMessage] = []
    children: list[Thought] = []

In [127]:
class ToTAgent:
    def __init__(self) -> None:
        with open("../prompts/tot_prompts.yaml") as f:
            self.prompts = yaml.safe_load(f)

        self.decomp_llm = ChatOpenAI(model="gpt-4o")

        # high temperature for more creative responses, low top_p for more likely responses
        # source: https://medium.com/@1511425435311/understanding-openais-temperature-and-top-p-parameters-in-language-models-d2066504684f
        self.cognition_llm = ChatOpenAI(model="gpt-4o", temperature=0.9, top_p=0.5)
        self.evaluation_llm = ChatOpenAI(model="gpt-4o", temperature=0.2)

    # --------------------------------------------------------------------------------

    def schema_setup(self, state: SchemaAgentInput) -> ThoughtProcess:
        class Steps(BaseModel):
            analysis: str = Field(..., description="Analysis of the task")
            steps: list[str]

        prompt = ChatPromptTemplate.from_messages(
                    [("system", self.prompts["system_prompt"]), 
                    ("user", self.prompts["task_decomposition_prompt"])]
                )
        llm = self.decomp_llm.with_structured_output(Steps)
        chain = prompt | llm | RunnableLambda(lambda x: getattr(x, "steps"))

        steps: list[str] = chain.invoke({
            "input_problem": state.input_problem,
            "task_history": state.task_history,
            "task": state.task
        })

        return {"steps": steps}
    
    # --------------------------------------------------------------------------------

    def cognition(self, state: ThoughtProcess) -> ThoughtProcess:
        cognition_chain = self._create_thought_generation_chain()
        evaluation_chain = self._create_evaluation_chain()
        
        tree = Thought(
            thought=AIMessage("<SEED>")
        )

        def cognition_walk(parent: Thought, level: int):
            child1: Thought = cognition_chain.invoke({
                "task": state.task,
                "step": state.steps[level],
                "context": parent.context
            })
            child2: Thought = cognition_chain.invoke({
                "task": state.task,
                "step": state.steps[level],
                "context": parent.context
            })
            return evaluation_chain.invoke({
                "task": state.task,
                "step": state.steps,
                "thoughts": [child1, child2]
            })
        
        
        return {"thoughts": [*cognition_walk(tree, 0)]}


    def _create_thought_generation_chain(self):
        chain = (
            ChatPromptTemplate.from_messages([
                ("system", self.prompts["thought_generation_system_prompt"]),
                ("placeholder", "{context}"),
                ("user", self.prompts["thought_generation_prompt"])
            ])
            | RunnableParallel(
                response=self.cognition_llm,
                context=RunnableLambda(lambda p: p.messages[-1:])
            )
            | RunnableParallel(
                response=RunnablePick(keys="response"),
                context=lambda x: x["context"] + [x["response"]]
            ) 
            | RunnableLambda(
                lambda x: Thought(
                    thought=x["response"],
                    context=x["context"]
                )
            )
        )
        return chain
    
    def _create_evaluation_chain(self):
        class EvalResults(BaseModel):
            evaluation_text: str = Field(..., description="Here you can think before answering")
            scores: list[float] = Field(..., description="List of scores [0-1], in the order of thoughts within the context")

        chain = (
            RunnableLambda(
                lambda x: {
                    **x,
                    "context": [t.thought for t in x["thoughts"]]
                } 
            )
            | RunnableParallel(
                evaluation=(
                    ChatPromptTemplate.from_messages(
                        [
                            ("user", self.prompts["evaluation_system_prompt"]),
                            ("placeholder", "{context}"),
                            ("user", self.prompts["evaluation_prompt"])
                        ]
                    ) 
                    | ChatOpenAI(model="gpt-4o", temperature=0.2).with_structured_output(EvalResults)
                ),
                thoughts=lambda x: x["thoughts"]
            )
            | RunnableParallel(
                evaluation=lambda x: x["evaluation"].scores,
                thoughts=lambda x: x["thoughts"]
            )
            | RunnableLambda(
                lambda x: [
                    Thought(
                        thought=x["thoughts"][i].thought,
                        evaluation=score,
                        context=x["thoughts"][i].context,
                        children=x["thoughts"][i].children
                    )
                    for i, score in enumerate(x["evaluation"])
                ]
            )
        )

        return chain

    # --------------------------------------------------------------------------------


In [68]:
with open("../prompts/tot_prompts.yaml") as f:
    PROMPTS = yaml.safe_load(f)

In [123]:
agent = ToTAgent()

workflow = StateGraph(ThoughtProcess, input=SchemaAgentInput)
workflow.add_node("setup", agent.schema_setup)
workflow.add_node("cognition", agent.cognition)
workflow.add_edge(START, "setup")
workflow.add_edge("setup", "cognition")

graph = workflow.compile()

In [124]:
response = graph.invoke(SchemaAgentInput(
    input_problem="is tomato puree and tomato sauce the same thing?",
    task=Task(description="Find the difference between tomato puree and tomato sauce"),
    task_history=[]
))

In [125]:
pprint(response)

{'input_problem': 'is tomato puree and tomato sauce the same thing?',
 'steps': ['Identify the ingredients and preparation process for tomato puree.',
           'Identify the ingredients and preparation process for tomato sauce.',
           'Compare the consistency and texture of tomato puree and tomato '
           'sauce.',
           'Determine the typical culinary uses of tomato puree and tomato '
           'sauce.'],
 'task': Task(description='Find the difference between tomato puree and tomato sauce'),
 'task_history': []}
