In [1]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from uuid import UUID, uuid4

from loguru import logger
from pydantic_ai import Agent, RunContext
from pydantic_ai import messages as _messages
from pydantic_graph import BaseNode, End, Graph, GraphRunContext

from knd.memory import AgentExperience, AgentMemories, Memory, Profile, UserSpecificExperience

%load_ext autoreload
%autoreload 2

In [2]:
# Length budget for the conversation summary; interpolated into SUMMARY_PROMPT below.
# NOTE(review): the unit is not established elsewhere in this notebook; the prompt treats it as characters — confirm.
SUMMARY_LIMIT = 20_000
# Passed to AgentMemories.load as message_count_limit.
MESSAGE_COUNT_LIMIT = 20
# Directory passed to AgentMemories.load and AgentMemories.dump.
MEMORIES_DIR = Path("memories")
# Prompt for the summary run in CreateUserSpecificExperience (that node references this name).
SUMMARY_PROMPT = (
    "Summarize the conversation so far, focusing on what you learned about the user "
    "and the topics that were discussed. "
    f"Keep the summary under {SUMMARY_LIMIT:,} characters."
)

In [3]:
# A fixed id is passed to AgentMemories.load (and logged by the graph nodes), so reruns
# presumably resume the same user's stored memories. Use uuid4() instead for a brand-new user.
user_id = UUID("db5fe6ca-55ae-4f38-9acf-62d707a46041")

In [4]:
@dataclass
class GraphState:
    """Mutable state shared by every node of the memory graph.

    The Chat node drives ``agent`` with ``agent_memories`` as its deps; the experience
    nodes use ``memory_agent`` and write their results back into ``agent_memories``.
    """

    user_id: UUID  # the user this conversation belongs to; appears in the experience nodes' log messages
    agent: Agent[AgentMemories, str]  # chat agent run by the Chat node, with agent_memories as deps
    agent_memories: AgentMemories  # loaded memories; user_specific_experience / agent_experience are overwritten by later nodes
    memory_agent: Agent  # auxiliary agent that extracts profile, memories, summary and agent experience


@dataclass
class Chat(BaseNode[GraphState]):
    """Interactive chat loop between the user and ``ctx.state.agent``.

    Starts from ``user_prompt`` and the message history already held by
    ``agent_memories``. Each agent reply is shown as the ``input()`` prompt. The loop
    stops when the user enters ``q``, ``quit`` or ``exit`` (case-insensitive, surrounding
    whitespace ignored). The accumulated history is then handed to
    ``CreateUserSpecificExperience``.
    """

    user_prompt: str

    async def run(self, ctx: GraphRunContext[GraphState]) -> CreateUserSpecificExperience:
        quit_words = ("q", "quit", "exit")
        message_history = ctx.state.agent_memories.message_history
        # Strip so inputs like " q" or "quit " still end the chat.
        user_prompt = self.user_prompt.strip()
        while user_prompt.lower() not in quit_words:
            if user_prompt:
                res = await ctx.state.agent.run(
                    user_prompt=user_prompt, deps=ctx.state.agent_memories, message_history=message_history
                )
                message_history = res.all_messages()
                prompt_text = f"{res.data}    (q to quit)> "
            else:
                # Never send an empty message to the model; just ask again.
                prompt_text = "(q to quit)> "
            # NOTE: input() is synchronous and blocks the event loop; acceptable for this single-user notebook.
            user_prompt = input(prompt_text).strip()
        return CreateUserSpecificExperience(message_history=message_history)


@dataclass
class CreateUserSpecificExperience(BaseNode[GraphState]):
    """Distil the finished chat into a ``UserSpecificExperience`` for the current user.

    Three sequential ``memory_agent`` runs over the same history extract a ``Profile``,
    a list of ``Memory`` items and a plain-text summary. Newly extracted memories are
    appended after any memories already stored for this user. The result replaces
    ``agent_memories.user_specific_experience``.
    """

    message_history: list[_messages.ModelMessage]

    async def run(self, ctx: GraphRunContext[GraphState]) -> CreateAgentExperience:
        state = ctx.state
        history = self.message_history
        logger.info(
            f"Creating user specific experience for Agent {ctx.state.agent_memories.agent_name} and User {ctx.state.user_id}"
        )

        async def extract(prompt, result_type):
            # One memory_agent pass over the whole conversation, returning its typed result.
            outcome = await state.memory_agent.run(
                user_prompt=prompt, result_type=result_type, message_history=history
            )
            return outcome.data

        profile = await extract(Profile.user_prompt(), Profile)
        new_memories = await extract(Memory.user_prompt(), list[Memory])
        previous = state.agent_memories.user_specific_experience
        # Keep what was already known about this user; fresh memories go last.
        all_memories = previous.memories + new_memories if previous else new_memories
        # SUMMARY_PROMPT is a module-level constant expected to be defined in the notebook's config cell.
        summary = await extract(SUMMARY_PROMPT, str)
        state.agent_memories.user_specific_experience = UserSpecificExperience(
            profile=profile, memories=all_memories, summary=summary, message_history=history
        )
        return CreateAgentExperience(message_history=history)


@dataclass
class CreateAgentExperience(BaseNode[GraphState]):
    """Ask ``memory_agent`` for an ``AgentExperience`` drawn from the finished chat.

    The result overwrites ``agent_memories.agent_experience``, and the graph then moves on to ``Memorize``.
    """

    message_history: list[_messages.ModelMessage]

    async def run(self, ctx: GraphRunContext[GraphState]) -> Memorize:
        memories = ctx.state.agent_memories
        logger.info(
            f"Creating agent experience for Agent {ctx.state.agent_memories.agent_name} and User {ctx.state.user_id}"
        )
        prompt = AgentExperience.user_prompt()
        outcome = await ctx.state.memory_agent.run(
            message_history=self.message_history,
            result_type=AgentExperience,
            user_prompt=prompt,
        )
        memories.agent_experience = outcome.data
        return Memorize()


@dataclass
class Memorize(BaseNode[GraphState, None, None]):
    """Persist the updated memories for the user recorded in the graph state, then end the run."""

    async def run(self, ctx: GraphRunContext[GraphState]) -> End:
        # Use the state's user_id rather than the notebook-level global, so the dump is written
        # for the same user the graph was run with (and the node does not depend on notebook globals).
        ctx.state.agent_memories.dump(user_id=ctx.state.user_id, memories_dir=MEMORIES_DIR)
        return End(None)


In [4]:
# NOTE(review): a later cell redefines `agent_name` and `anime_agent` identically and also attaches
# the dynamic system prompt; this earlier agent is shadowed once that cell runs.
agent_name = "anime_fan2"
anime_agent = Agent(
    model="google-gla:gemini-2.0-flash-exp",
    name=agent_name,
    result_type=str,
    deps_type=AgentMemories,
    system_prompt="You are an anime fan. The user will also be an anime fan. Just bros chilling talking about anime.",
)


In [5]:
# Sanity check: display the name the agent was constructed with.
anime_agent.name

'anime_fan2'

In [5]:
# Node order follows the run: Chat -> user experience -> agent experience -> persist.
anime_graph = Graph(
    nodes=[
        Chat,
        CreateUserSpecificExperience,
        CreateAgentExperience,
        Memorize,
    ]
)

In [6]:
# Canonical chat agent (supersedes the identical agent from the earlier cell), plus the
# dynamic system prompt that feeds the loaded memories into every run.
agent_name = "anime_fan2"

anime_agent = Agent(
    name=agent_name,
    model="google-gla:gemini-2.0-flash-exp",
    deps_type=AgentMemories,
    result_type=str,
    system_prompt="You are an anime fan. The user will also be an anime fan. Just bros chilling talking about anime.",
)


@anime_agent.system_prompt(dynamic=True)
def system_prompt(ctx: RunContext[AgentMemories]) -> str:
    """Render the run's AgentMemories deps as an additional system prompt."""
    memories_text = str(ctx.deps)
    return memories_text


# Helper agent with no default result type; the graph nodes pass result_type on each call.
memory_agent = Agent(model="google-gla:gemini-1.5-flash", name="memory_agent")

graph_state = GraphState(
    user_id=user_id,
    agent=anime_agent,
    memory_agent=memory_agent,
    agent_memories=AgentMemories.load(
        agent_name=agent_name,
        user_id=user_id,
        memories_dir=MEMORIES_DIR,
        message_count_limit=MESSAGE_COUNT_LIMIT,
    ),
)


In [None]:
_, history = await anime_graph.run(Chat(user_prompt="hi"), state=graph_state)

In [None]:
# Inspect a data snapshot of every step recorded during the graph run.
list(step.data_snapshot() for step in history)

In [None]:
# Inspect the message_history attribute of the graph state's AgentMemories after the run.
# NOTE(review): Chat reads this attribute but never assigns it; the graph writes
# user_specific_experience and agent_experience instead. Whether this reflects the new
# conversation depends on AgentMemories — confirm.
graph_state.agent_memories.message_history