In [1]:
from dataclasses import dataclass, field
from typing import Literal

from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from rich.markup import escape
from rich.prompt import Prompt

from knd.prompts import ASSISTANT_PROMPT
from knd.result_types import TaskSpecificExperience


In [2]:
# Model identifier passed to `Agent` below, in pydantic-ai's "provider:model" form.
MODEL = "openai:gpt-4o-mini"


class Profile(BaseModel):
    # Structured result the agent returns once it has collected all three fields
    # (see the agent's system prompt). Documented with comments rather than a class
    # docstring because pydantic would copy a docstring into the JSON schema
    # `description`, changing what the model is shown.
    name: str
    age: int  # no range constraint: any int the model produces is accepted
    gender: Literal["male", "female", "other"]  # validation rejects any other value


@dataclass
class Task:
    """Mutable per-conversation state handed to the agent as ``deps``.

    The tools write into it and ``assistant_prompt`` reads it back. The run loop
    reuses a single instance across ``agent.run`` calls, so it carries state between turns.
    """

    # Latest task-specific experience saved by the model; None until that tool runs.
    tse: TaskSpecificExperience | None = field(default=None)
    # Thoughts appended by the ``contemplate`` tool, oldest first.
    inner_monologue: list[str] = field(default_factory=list)


# The profile agent. A `str` result is a follow-up question for the user; a
# `Profile` result ends the conversation (see the run loop further down).
agent = Agent(
    MODEL,
    result_type=Profile | str,  # type: ignore
    deps_type=Task,
    # Implicitly concatenated pieces: identical to a single long literal.
    system_prompt=(
        "Generate a profile for the user. "
        "The profile should have name, age, gender. "
        "If you need more information, ask the user. "
        "Once the profile is complete, return the profile."
    ),
)


@agent.tool
async def contemplate(ctx: RunContext[Task], inner_monologue: str) -> None:
    "Contemplate and write down your thoughts in the inner monologue. Must be called before each step"
    # The docstring above is what `@agent.tool` shows the model as the tool
    # description, so it is kept verbatim. The thought is recorded on the shared
    # deps; `assistant_prompt` folds the accumulated list into the system prompt.
    # Nothing is returned, so the tool result carries no extra text.
    thoughts = ctx.deps.inner_monologue
    thoughts.append(inner_monologue)


@agent.tool
async def extract_task_specific_experience(ctx: RunContext[Task], tse: TaskSpecificExperience) -> None:
    "Extract and save the task specific experience. MUST be called before the profile is returned using the `final_result` tool"
    # Model-facing description above is kept verbatim. Each call replaces any
    # previously saved experience; `assistant_prompt` serializes it into the
    # system prompt. The "MUST" is only a request to the model, not enforced here.
    state = ctx.deps
    state.tse = tse


@agent.tool
def get_weather(
    ctx: RunContext[Task],
    day: Literal["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
    location: str,
) -> str:
    "Get the weather at a day in a location"
    # Debug output: dump the conversation so far each time the tool is invoked.
    print(f"\n\nMessage history: {ctx.messages}")
    # Canned forecasts keyed by (day, lower-cased city); every other pair is overcast.
    canned = {
        ("Monday", "blackpool"): "It's raining",
        ("Tuesday", "london"): "It's sunny",
        ("Wednesday", "manchester"): "It's cloudy",
        ("Thursday", "new york"): "It's snowing",
    }
    return canned.get((day, location.lower()), "It's overcast")


@agent.result_validator  # type: ignore
async def result_validator(ctx: RunContext[Task], result: Profile | str) -> Profile | str:
    """Debug hook: print run state, then pass the result through unchanged.

    Never raises, so it never asks the model to retry; `result` is returned as-is.
    """
    print(f"\n\nMessage history in result validator: {ctx.messages}")
    # `_function_tools` is a private pydantic-ai attribute with no stability
    # guarantee. Read it defensively so a library change that drops it prints
    # `None` instead of aborting the whole run from a debug-only print.
    print(f"\n\nFunction Tools: {getattr(agent, '_function_tools', None)}")
    print(f"\n\nUsage: {ctx.usage}")
    return result


@agent.system_prompt(dynamic=True)
def assistant_prompt(ctx: RunContext[Task]) -> str:
    """Build the system prompt from ASSISTANT_PROMPT plus the state saved on deps.

    Each optional section is appended after a blank line, and only when the
    corresponding piece of state is truthy.
    """
    state = ctx.deps
    sections = [ASSISTANT_PROMPT]
    if state.tse:
        sections.append(f"Task Specific Experience So Far: {state.tse.model_dump_json()}")
    if state.inner_monologue:
        # Rendered with the list's repr, e.g. "['first thought', 'second']".
        sections.append(f"Inner Monologue So Far: {state.inner_monologue}")
    return "\n\n".join(sections)


In [None]:
message_history = None
user_prompt = "hello"
deps = Task()
while True:
    res = await agent.run(user_prompt=user_prompt, deps=deps, message_history=message_history)
    profile = res.data
    if isinstance(profile, str):
        user_prompt = Prompt.ask(profile)
        message_history = res.all_messages()
    else:
        break

In [None]:
# Messages from the final `agent.run` call. The loop feeds `all_messages()` back in
# as `message_history`, so this should span the whole conversation.
res.all_messages()

In [None]:
# Thoughts recorded by the `contemplate` tool across every turn (the same `deps` was reused).
deps.inner_monologue

In [None]:
# Task-specific experience saved by `extract_task_specific_experience`. That tool is
# only *requested* in its description, never enforced, so `tse` can still be None
# here; guard so this cell does not raise AttributeError on Restart & Run All.
deps.tse.model_dump() if deps.tse is not None else None