
# Gemini + Pydantic AI + Short-Term Memory (Notebook Demo)

This notebook shows how to:
- Maintain **short-term conversation memory** with a small in-memory store
- Build a **Pydantic AI** Agent using **Gemini**
- Run a mini loop that reuses memory between turns

> Before running the agent cell, set `GOOGLE_API_KEY` in your environment (or `os.environ` in a cell).


In [1]:

# %% Optional: install deps (uncomment if you need)
# %pip install pydantic-ai google-generativeai httpx
# If you're using the "slim" build, install it with the Google extra instead:
# %pip install "pydantic-ai-slim[google]"


In [2]:

from __future__ import annotations
from typing import List, Dict, Optional, Any
from datetime import datetime, timedelta

# ---- Minimal STM (In-memory) ----
class Turn:
    """A single conversation message with its bookkeeping data.

    Instances carry the speaker ``role``, the message ``content``, a free-form
    ``meta`` dict and ``created_at``, the naive-UTC creation time rendered as
    an ISO-8601 string.
    """

    def __init__(self, role: str, content: str, meta: Optional[Dict[str, Any]] = None):
        # A missing or empty meta is replaced by a fresh dict.
        self.meta = meta if meta else {}
        self.created_at = datetime.utcnow().isoformat()
        self.role, self.content = role, content

    def to_dict(self):
        """Return the turn as a plain dict, keyed by field name."""
        fields = ("role", "content", "meta", "created_at")
        return {field: getattr(self, field) for field in fields}

def estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of ``text`` as one token per 4 characters.

    Empty (or otherwise falsy) input counts as 0 tokens; any non-empty text
    counts as at least 1.
    """
    if text:
        return max(len(text) // 4, 1)
    return 0

class InMemoryShortTermMemory:
    """Process-local short-term conversation memory keyed by session id.

    Every session owns a list of turn dicts (built with ``Turn.to_dict``) and
    one expiry timestamp. An expired session is discarded lazily, the next
    time it is read or written.
    """

    def __init__(self, max_turns_per_session: int = 500, default_ttl_sec: int = 3600):
        """Create an empty store.

        Args:
            max_turns_per_session: Maximum number of turns kept per session;
                the oldest turns are dropped first.
            default_ttl_sec: Session lifetime in seconds, used whenever
                ``append_turn`` is called without an explicit ``ttl_sec``.
        """
        self.store: Dict[str, List[Dict[str, Any]]] = {}
        self.expiry: Dict[str, datetime] = {}
        self.max_turns = max_turns_per_session
        self.default_ttl = default_ttl_sec

    def _ttl_seconds(self, ttl_sec: Optional[int]) -> int:
        """Resolve the effective TTL: the default applies only when none was given."""
        # An explicit 0 is a valid request and must not fall back to the default.
        return self.default_ttl if ttl_sec is None else ttl_sec

    def _expired(self, sid: str) -> bool:
        """Return True when ``sid`` has an expiry timestamp that lies in the past."""
        exp = self.expiry.get(sid)
        return exp is not None and datetime.utcnow() > exp

    def _gc(self, sid: str):
        """Drop the turns and expiry entry of ``sid`` if the session has expired."""
        if self._expired(sid):
            self.store.pop(sid, None)
            self.expiry.pop(sid, None)

    def append_turn(self, session_id: str, role: str, content: str, meta: Optional[Dict[str, Any]] = None, ttl_sec: Optional[int] = None):
        """Append one turn to a session and refresh the session's expiry.

        Args:
            session_id: Session to append to; created on first use.
            role: Speaker label stored with the turn.
            content: Message text.
            meta: Optional extra data stored with the turn.
            ttl_sec: Lifetime in seconds counted from now. ``None`` selects
                the store's default TTL.
        """
        self._gc(session_id)
        turns = self.store.setdefault(session_id, [])
        turns.append(Turn(role, content, meta).to_dict())
        overflow = len(turns) - self.max_turns
        if overflow > 0:
            # Trim from the front so only the newest max_turns entries remain.
            del turns[:overflow]
        self.expiry[session_id] = datetime.utcnow() + timedelta(seconds=self._ttl_seconds(ttl_sec))

    def get_all(self, session_id: str) -> List[Dict[str, Any]]:
        """Return a shallow copy of the session's turns, oldest first.

        Unknown or expired sessions yield an empty list.
        """
        self._gc(session_id)
        return list(self.store.get(session_id, []))

    def get_by_token_budget(self, session_id: str, max_tokens: int = 2000) -> List[Dict[str, Any]]:
        """Return the newest turns whose estimated token total fits ``max_tokens``.

        Turns are collected from newest to oldest and returned oldest first.
        The newest turn is always included, even when it alone exceeds the
        budget, so a non-empty session never produces an empty result.
        """
        selected: List[Dict[str, Any]] = []
        running = 0
        for turn in reversed(self.get_all(session_id)):
            cost = estimate_tokens(turn["content"])
            if running + cost > max_tokens:
                if not selected:
                    # Over budget on the very first turn: keep it anyway.
                    selected.append(turn)
                break
            selected.append(turn)
            running += cost
        selected.reverse()
        return selected

    def to_prompt_lines(self, session_id: str, max_tokens: int = 2000) -> str:
        """Render the budgeted turns as ``[role] content`` lines joined by newlines."""
        return "\n".join(
            f"[{t['role']}] {t['content']}"
            for t in self.get_by_token_budget(session_id, max_tokens=max_tokens)
        )

# Shared module-level memory instance; later cells read and write it by this name.
stm = InMemoryShortTermMemory()
print("STM ready.")


STM ready.


In [15]:

import os
from pydantic_ai import Agent, RunContext, ModelSettings
from pydantic_ai.models.google import GoogleModelSettings

# Model used for every run; swap in a lighter one (e.g. "google-gla:gemini-2.5-flash")
# if you prefer speed.
MODEL_ID = "google-gla:gemini-2.5-pro"
# An unset or empty key both count as missing.
if not os.getenv("GOOGLE_API_KEY"):
    print("⚠️ Set GOOGLE_API_KEY in your environment before running the next cell.")

# --- Agent definition ---
ai_data_analyst = Agent[str, str](
    MODEL_ID,
    deps_type=str,            # we'll pass a 'session_id' as the dependency
    output_type=str,
    instructions=(
        "You are a precise data-analyst assistant. "
        "Use the provided memory recap to stay consistent with the user's prior context. "
        "When you need clarification, ask concise follow-ups. "
        "Prefer structured, bullet-point answers with clear steps or queries."
    ),
    # Resolve the named model on the first run instead of here, so this cell
    # still executes (and the warning above is seen) when the key is not set yet.
    defer_model_check=True,
    # model_settings=ModelSettings(temperature=0.2, max_tokens=500),
)

@ai_data_analyst.system_prompt
def memory_recap(ctx: RunContext[str]) -> str:
    """Build the system-prompt section that recaps this session's stored turns.

    Args:
        ctx: Run context whose ``deps`` is the session id string.

    Returns:
        A short header followed by the turns rendered by
        ``stm.to_prompt_lines`` (limited to an estimated 1200 tokens), or a
        fixed notice when nothing is stored for the session.
    """
    session_id = ctx.deps  # deps is our session_id string
    recap = stm.to_prompt_lines(session_id, max_tokens=1200)
    if not recap.strip():
        return "No prior memory for this session."
    # to_prompt_lines emits turns oldest-first, so tell the model that order.
    return (
        "Conversation memory (oldest turn first, most recent last). "
        "Use for reference only, don't repeat verbatim.\n\n" + recap
    )

# A tiny tool to show how you'd hook DB/metrics calls.
@ai_data_analyst.tool
def echo_last_user(ctx: RunContext[str], n: int = 1) -> str:
    """Return the last n user turns from STM (for debugging/demo).

    Args:
        n: How many of the most recent user turns to return; values below 1 are treated as 1.
    """
    session_id = ctx.deps  # deps carries the session id string
    turns = [t for t in stm.get_all(session_id) if t["role"] == "user"]
    if not turns:
        return "No user turns yet."
    # n is chosen by the model, so clamp it: turns[-0:] would return every
    # turn, and a negative n would slice from the front of the list.
    count = max(1, n)
    sel = turns[-count:]
    return "\n".join(f"{i+1}. {t['content']}" for i, t in enumerate(sel))

# Confirm the cell ran to the end and show which model id the agent targets.
print("Agent ready:", MODEL_ID)


Agent ready: google-gla:gemini-2.5-pro


In [16]:
from typing import Optional
from pydantic_ai.models.google import GoogleModelSettings

async def ask_async(session_id: str, user_question: str, google_settings: Optional[GoogleModelSettings] = None) -> str:
    """Ask the agent one question and record the exchange in short-term memory.

    Args:
        session_id: Key of the STM session; also passed to the agent as ``deps``.
        user_question: The user's message for this turn.
        google_settings: Model settings for this run. When omitted,
            ``GoogleModelSettings(temperature=0.2)`` is used.

    Returns:
        The agent's text output for this turn.
    """
    # Only substitute the default when nothing was passed; an explicitly
    # supplied (even empty) settings object is respected.
    if google_settings is None:
        google_settings = GoogleModelSettings(temperature=0.2)

    # run the agent asynchronously so we don't collide with Jupyter's event loop
    result = await ai_data_analyst.run(
        user_question,
        deps=session_id,
        model_settings=google_settings,
    )

    # Record both turns only after the run succeeded. Storing the user turn
    # first would put the current question into the memory recap as well as
    # the prompt, and a failed run would leave an unanswered user turn behind.
    stm.append_turn(session_id, "user", user_question)
    stm.append_turn(
        session_id,
        "assistant",
        result.output,
        # meta={"model": result.model, "usage": result.usage().model_dump()}
    )

    # usage = result.usage()
    # print(f"[model={result.model_name}] tokens in={usage.input_tokens} out={usage.output_tokens}")
    return result.output


In [18]:
session_id = "demo-session-001"

print("Q1")
print(await ask_async(session_id, "My name is TJ and my favourite food is lasagna."))

print("\nQ2")
print(await ask_async(session_id, "What did I say my name was?"))

print("\nQ3 (uses tool)")
print(await ask_async(session_id, "What is my favourite food?"))


Q1
Got it, TJ. Lasagna is a great choice

Q2
You said your name is TJ.

Q3 (uses tool)
You said your favorite food is lasagna.
