In [1]:
from dataclasses import dataclass

from agent_prompts import CREATE_USER_PROFILE_PROMPT, INVESTOR_ASSISTANT_PROMPT, RISK_ANALYST_PROMPT
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from rich.prompt import Prompt

from knd.ai import get_messages_for_agent_tool, run_until_completion
from knd.prompts import ASSISTANT_PROMPT

In [2]:
# Model identifier string passed as the first argument to every Agent built in this notebook.
MODEL = "google-gla:gemini-1.5-flash"


class UserProfile(BaseModel):
    # Structured description of the user. It is one of the two result types of
    # `user_profile_creator` (the other being plain `str`) and is what the
    # `create_user_profile` tool stores on `FinDeps.user_profile`.
    #
    # Active fields. Units/currency of `income` are not specified anywhere in
    # this notebook — TODO confirm against CREATE_USER_PROFILE_PROMPT.
    name: str
    age: int
    investment_goals: str
    income: float
    # Candidate fields below are currently disabled (commented out), so they are
    # NOT part of the model. Enabling them would also require `Literal` and
    # `Field` to be imported, which the import cell does not do today.
    # risk_tolerance: Literal["conservative", "moderate", "aggressive"]
    # savings_capacity: float = Field(description="The amount of money the user can save each month")
    # financial_knowledge_level: Literal["beginner", "intermediate", "advanced", "expert"]
    # investment_experience: str = Field(
    #     description="Detailed description of user's investment history and experience across different investment types"
    # )
    # tax_considerations: str
    # life_stage: str = Field(
    #     description="The current phase of life (e.g., Student, Young Professional, Parent, Pre-retirement, Retired, etc.) that impacts financial planning needs."
    # )
    # obligations: str = Field(
    #     description="Financial obligations and commitments such as debts, loans, mortgages, or dependent care responsibilities"
    # )


@dataclass
class FinDeps:
    """Dependency container declared as `deps_type` for all three agents.

    A single instance is created for the chat loop and handed to every
    `investor_assistant.run` call; the tools forward the same object to the
    sub-agents via `ctx.deps`.
    """

    # Starts as None; assigned by the `create_user_profile` tool once the
    # profile-creation sub-agent has produced its result.
    user_profile: UserProfile | None = None


# All three agents share MODEL and the FinDeps dependency type; they differ in
# their system prompt and, for the profile creator, in the result type.

# Front-line agent the chat loop talks to; the tools below are registered on it.
investor_assistant = Agent(
    MODEL,
    deps_type=FinDeps,
    result_type=str,
    system_prompt=INVESTOR_ASSISTANT_PROMPT,
)

# Sub-agent driven by the `create_user_profile` tool. Its result is either a
# structured UserProfile or a plain string.
user_profile_creator = Agent(
    MODEL,
    deps_type=FinDeps,
    result_type=UserProfile | str,  # type:ignore
    system_prompt=CREATE_USER_PROFILE_PROMPT,
)

# Sub-agent driven by the `risk_analysis` tool.
risk_analyst = Agent(
    MODEL,
    deps_type=FinDeps,
    result_type=str,
    system_prompt=RISK_ANALYST_PROMPT,
)


@investor_assistant.system_prompt(dynamic=False)
def assistant_prompt() -> str:
    """Return the shared ASSISTANT_PROMPT, registered as a system prompt on `investor_assistant`."""
    # Disabled idea from an earlier draft: accept a RunContext and, when
    # ctx.deps.user_profile is set, append
    # "\n\nUser Profile so far: " + ctx.deps.user_profile.model_dump_json().
    # Re-enabling it requires adding a `ctx` parameter to this function.
    return ASSISTANT_PROMPT


@investor_assistant.tool
async def create_user_profile(ctx: RunContext[FinDeps]) -> UserProfile:
    "Create a user profile. Must be called before anything else. Start with this tool."
    # NOTE(review): the one-line string above is presumably what the model sees
    # as this tool's description, so its wording is kept exactly as-is — confirm
    # before rephrasing it.
    #
    # Flow: build a message history for the profile sub-agent from the current
    # run context, drive that sub-agent via `run_until_completion`, then cache
    # the outcome on the shared deps object and hand it back to the caller.
    print("Creating user profile")
    profile_request = "generate the user profile"
    history = await get_messages_for_agent_tool(
        agent=user_profile_creator,  # type: ignore
        user_prompt=profile_request,
        ctx=ctx,
    )
    outcome = await run_until_completion(
        agent=user_profile_creator,  # type: ignore
        user_prompt=profile_request,
        message_history=history,
        deps=ctx.deps,
    )
    profile = outcome.data
    # The sub-agent's declared result type also allows `str`, hence the ignores.
    ctx.deps.user_profile = profile  # type: ignore
    return profile  # type: ignore


@investor_assistant.tool
async def risk_analysis(ctx: RunContext[FinDeps]) -> str:
    "Analyze the risk of the user's investment plan"
    # NOTE(review): the one-line string above is presumably the tool description
    # shown to the model, so it is reproduced verbatim — confirm before editing.
    #
    # Flow: derive a message history for the risk analyst from the current run
    # context, run the analyst once with the shared deps, and return its text.
    analysis_request = "analyze the risk of the user's investment plan"
    history = await get_messages_for_agent_tool(
        agent=risk_analyst,  # type: ignore
        user_prompt=analysis_request,
        ctx=ctx,
    )
    outcome = await risk_analyst.run(
        user_prompt=analysis_request,
        deps=ctx.deps,
        message_history=history,
    )
    return outcome.data


In [None]:
message_history = None
user_prompt = "hello"
deps = FinDeps()
while user_prompt.lower() not in ["q", "quit", "exit"]:
    res = await investor_assistant.run(user_prompt=user_prompt, deps=deps, message_history=message_history)
    user_prompt = Prompt.ask(res.data)
    message_history = res.all_messages()

In [None]:
# Show the messages of the last `investor_assistant.run` result from the chat loop.
# Requires the chat-loop cell to have completed at least one turn so that `res` exists.
res.all_messages()