<a href="https://colab.research.google.com/github/Dimildizio/DS_course/blob/main/Neural_networks/Agents/pydantic_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%%capture
!pip install 'pydantic-ai-slim[openai]'

# Step 1: Basic model connector


### Imports and env variables

First we import libs and create env variable for `OpenAIChatModel` to be able to get `OPENROUTER_API_KEY` from env variables

In [3]:
import os

from google.colab import userdata
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openrouter import OpenRouterProvider


OPENROUTER_API_KEY = userdata.get('openrouter')
os.environ["OPENROUTER_API_KEY"] = OPENROUTER_API_KEY

### Initiating model and agent


We create an agent output scheme

In [4]:
class Result(BaseModel):
    answer: str
    confidence: float

> Write system prompt

> create model

> create agent

In [5]:
sys_prompt = "You're a study assistant agent. Your answer should be brief and structured. Avoid emoji and slang."

model = OpenAIChatModel("deepseek/deepseek-chat-v3.1:free",
                        provider=OpenRouterProvider())


agent = Agent[None, Result](model=model,
                            system_prompt=sys_prompt)

### Handling `agent.run` and memory

Simple async run function with an agent decorator. In Pydantic AI the context "lives" while the `agent.run` is being executed (unlike usual LLM request `model.complete` or `model.generate` when it doesn't have access to it's own history/tool use), however between requests there is no memory.

The agent has access to `ctx.deps` - like ram and `ctx.memory_history` - all prompts to the model in current run, after `agent.run` return result the session is closed and the memory is erased.

In [6]:
@agent.run
async def run(ctx: RunContext[None], question: str) -> Result:
    draft = await ctx.llm.complete(question)
    return Result(answer=draft.data, confidence=0.5)

### Try it

In [8]:
result = await agent.run("Hello, who are you?")
print(result)


AgentRunResult(output='I am an AI assistant designed to help answer questions and provide support. How can I assist you today?')


## Step 2: add tools

### More imports

In [9]:
from typing import Dict, Any, Optional, List
from pydantic import Field
import ast, operator as op

### Create schema for dependencies - data accessible to the Agent during agent.run
We create pydantic model so it would follow typisation, and get validated
When `agent.run` is executed `Deps` gets instantiated and is accessible via `ctx.deps`

In [36]:
class Deps(BaseModel):
    log: List[Dict[str, Any]] = Field(default_factory=list)

class Result(BaseModel):
    answer: str
    confidence: float

### Create a new system prompt and agent

In [37]:
sys_prompt = """You are a calculation-and-text utility agent.
Decide if a TOOL is needed. If so, emit a single JSON command with:
  {"tool": "count_word_chars"|"calc_expr"|"convert_units"|"none", "args": {...}}
Then STOP. The python runtime will execute the tool and show you its result.
After you see the tool RESULT, produce the final short answer.

Guidelines:
- If the user asks to count characters per word -> choose "count_word_chars" and pass ONLY the target text as args: {"text": "..."}.
- If it's a math expression (e.g., "3*(2+5)/7" or "2**10 + 0.5") -> choose "calc_expr" with {"expression": "..."}.
- If it's a unit conversion like "<number> <unit> to <unit>" -> choose "convert_units" with {"query": "..."}.
- Keep answers short and precise. If numeric, one line with the number and brief explanation.
- Do NOT invent data. Prefer tools when applicable.
"""


agent = Agent[Deps, Result](model=model,
                            system_prompt=sys_prompt)

### Tool 1: calculator

In [12]:
class CalcInput(BaseModel):
    expression: str

class CalcOutput(BaseModel):
    value: float

In [38]:
_ALLOWED_OPS = {
    ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
    ast.Div: op.truediv, ast.Pow: op.pow, ast.Mod: op.mod,
    ast.USub: op.neg, ast.UAdd: op.pos,
}

def _safe_eval(node):
    if isinstance(node, ast.Constant):
        if isinstance(node.value, (int, float)): return node.value
        raise ValueError("Only numeric constants allowed")
    if isinstance(node, ast.UnaryOp) and type(node.op) in _ALLOWED_OPS:
        return _ALLOWED_OPS[type(node.op)](_safe_eval(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _ALLOWED_OPS:
        return _ALLOWED_OPS[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
    if isinstance(node, ast.Expr): return _safe_eval(node.value)
    raise ValueError("Unsupported expression")


In [39]:
@agent.tool
def calc_expr(ctx: RunContext[Deps], data: CalcInput) -> CalcOutput:
    """Safely evaluate arithmetic expression: + - * / ** % and parentheses.
    Example: calc_expr(expression="3*(2+5)/7")
    """
    tree = ast.parse(data.expression, mode="eval")
    val = float(_safe_eval(tree.body))
    ctx.deps.log.append({"tool": "calc_expr", "expr": data.expression, "value": val})
    return CalcOutput(value=val)

### Tool 2: Converter

In [15]:
class ConvertInput(BaseModel):
    query: str  # "<num> <unit> to <target unit>"

class ConvertOutput(BaseModel):
    value: float
    from_unit: str
    to_unit: str

In [40]:
_UNIT_TO_BASE = {
    # length -> meters
    "m": ("m", 1.0), "meter": ("m", 1.0), "meters": ("m", 1.0),
    "cm": ("m", 0.01), "mm": ("m", 0.001), "km": ("m", 1000.0),
    "inch": ("m", 0.0254), "in": ("m", 0.0254), "inches": ("m", 0.0254),
    "ft": ("m", 0.3048), "foot": ("m", 0.3048), "feet": ("m", 0.3048),
    "yd": ("m", 0.9144), "yard": ("m", 0.9144), "yards": ("m", 0.9144),
    # mass -> kilograms
    "kg": ("kg", 1.0), "g": ("kg", 0.001),
    "lb": ("kg", 0.45359237), "lbs": ("kg", 0.45359237),
    "pound": ("kg", 0.45359237), "pounds": ("kg", 0.45359237),
}

def _parse_convert(q: str):
    parts = q.strip().lower().split()
    if "to" not in parts or len(parts) < 3:
        raise ValueError("Use format like: '12 inches to cm'")
    to_idx = parts.index("to")
    num = float(parts[0].replace(",", "."))
    from_unit = parts[1]; to_unit = parts[to_idx + 1]
    return num, from_unit, to_unit

In [41]:
@agent.tool
def convert_units(ctx: RunContext[Deps], data: ConvertInput) -> ConvertOutput:
    """Convert '<number> <from_unit> to <to_unit>' for length and mass.
    Example: convert_units(query="12 inches to cm")
    """
    num, fu, tu = _parse_convert(data.query)
    if fu not in _UNIT_TO_BASE or tu not in _UNIT_TO_BASE:
        raise ValueError("Unsupported units")
    fam_f, kf = _UNIT_TO_BASE[fu]; fam_t, kt = _UNIT_TO_BASE[tu]
    if fam_f != fam_t:
        raise ValueError(f"Incompatible unit families ({fu} -> {tu})")
    base_val = num * kf; out_val = base_val / kt
    ctx.deps.log.append({"tool": "convert_units", "query": data.query, "value": out_val})
    return ConvertOutput(value=out_val, from_unit=fu, to_unit=tu)

### Tool 3: Evaluator

Evaluation tool compares a numeric output with expected one, taking into account a margin of error

In [18]:
class EvalInput(BaseModel):
    produced: str
    expected: float
    tol: float = 1e-6  # error margin

class EvalOutput(BaseModel):
    ok: bool
    reason: str

In [42]:
@agent.tool
def evaluate_nums(ctx: RunContext[Deps], data: EvalInput) -> EvalOutput:
    """Compare a numeric answer in 'produced' to 'expected' within tolerance."""
    import re
    m = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", data.produced.replace(",", "."))
    if not m:
        res = EvalOutput(ok=False, reason="no number found")
    else:
        got = float(m.group(0))
        res = EvalOutput(ok=abs(got - data.expected) <= data.tol, reason=f"got={got}, expected={data.expected}")
    ctx.deps.log.append({"tool": "evaluate_nums", **res.model_dump()})
    return res

### Tool 4: Letters counter

In [20]:
from collections import Counter
import re

In [35]:
class CharCountInput(BaseModel):
    text: str
    normalize: bool = True
    letters_only: bool = True

class CharCountPerWord(BaseModel):
    word: str
    counts: Dict[str, int]

class CharCountOutput(BaseModel):
    items: List[CharCountPerWord]

In [22]:
def _tokenize_words(text: str) -> List[str]:
    return re.findall(r"\b\w+\b", text, flags=re.UNICODE)

def format_char_counts(out: CharCountOutput) -> str:
    parts = []
    for item in out.items:
        counts_sorted = ", ".join(f"{ch}:{n}" for ch, n in sorted(item.counts.items()))
        parts.append(f"{item.word}:{{{counts_sorted}}}")
    return ", ".join(parts)

In [29]:
@agent.tool
def count_word_chars(ctx: RunContext[Deps], data: CharCountInput) -> CharCountOutput:
    """Count per-word character frequencies for the given text.
    Examples:
      count_word_chars(text="I'm strawberry")
      count_word_chars(text="I am a strawberry")
    Return a compact mapping per word.
    """
    text = (data.text or "").strip()
    if not text:
        return CharCountOutput(items=[])
    words = _tokenize_words(text)
    items: List[CharCountPerWord] = []
    for w in words:
        word_key = w.lower() if data.normalize else w
        letters = [
            (ch.lower() if data.normalize else ch)
            for ch in w
            if (ch.isalpha() if data.letters_only else True)
        ]
        c = Counter(letters)
        items.append(CharCountPerWord(word=word_key, counts=dict(c)))
    ctx.deps.log.append({"tool": "count_word_chars", "n_words": len(items)})
    return CharCountOutput(items=items)

### Main agent.run triage

  async def run(ctx: RunContext[Deps], question: str) -> Result:
