<a href="https://colab.research.google.com/github/Fuenfgeld/Agent_Tutorial_PydanticAI/blob/main/05_Agent_Self_Correction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages this notebook imports: pydantic-ai (agent framework),
# nest_asyncio (allows asyncio.run() inside Colab's already-running event loop),
# devtools, logfire (tracing) and aiosqlite (async SQLite driver used by the tools).
# NOTE(review): colorama is imported in a later cell but is not installed here —
# presumably it is preinstalled in the Colab runtime; confirm.
%pip -q install pydantic-ai
%pip -q install nest_asyncio
%pip -q install devtools
%pip -q install logfire
%pip -q install aiosqlite

In [None]:
import os

import nest_asyncio
from google.colab import userdata
from IPython.display import display, Markdown

# Read the API credentials from Colab's secret store (secret names as
# configured in this notebook's Secrets panel).
key = userdata.get('Claude')
keyOpenAI = userdata.get('openAI')
keyLogFire = userdata.get('logfire')

# Hand the model-provider keys to the SDKs through environment variables.
os.environ["OPENAI_API_KEY"] = keyOpenAI
os.environ["ANTHROPIC_API_KEY"] = key

# Patch asyncio so asyncio.run() works inside the notebook's running event loop.
nest_asyncio.apply()


In [None]:
# Enable Logfire tracing, authenticated with the write token read from the
# Colab secret 'logfire' (keyLogFire) in the setup cell.
import logfire
logfire.configure(token=keyLogFire)

In [None]:
import asyncio
import os
import aiosqlite
from sqlite3 import Connection
from colorama import Fore
from pydantic import BaseModel
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.openai import OpenAIModel

# Define the AI model to use for understanding natural language requests.
# NOTE(review): the OpenAI client is expected to pick up its key from the
# OPENAI_API_KEY environment variable exported in the setup cell.
model = OpenAIModel('gpt-4o-mini')

# Define the patient data structure using Pydantic for validation.
# The agent's final answer is validated against this model (it is the agent's
# structured output type). Fields mirror the columns of the `patients` table.
# Comments are used here instead of a class docstring on purpose: Pydantic
# copies a class docstring into the JSON schema description sent to the LLM.
class Patient(BaseModel):
    patient_id: int  # INTEGER PRIMARY KEY of the patients table
    mrn: str  # Medical Record Number
    full_name: str
    diagnosis: str

# Create an AI agent that can interact with the database.
# - output_type: the final answer must validate as a Patient. This is the
#   current spelling of the deprecated `result_type` option, and it matches
#   the `result.output` attribute read after agent.run().
# - deps_type: the run dependency is the live aiosqlite connection passed as
#   `deps=` to agent.run(), not a synchronous sqlite3.Connection.
agent = Agent(model=model,
              system_prompt="You are a clinical assistant. Help retrieve patient records using the available tools.",
              output_type=Patient,
              deps_type=aiosqlite.Connection)

# Agent tool: exact-match lookup of a patient by full name.
# Decorator order matters. pydantic-ai's @agent.tool registers the function it
# receives and returns it unchanged, so @logfire.instrument must be applied
# first (listed below it). The agent then calls the traced wrapper; with the
# order reversed, only direct calls to this module-level name would be traced.
# The docstring is used by pydantic-ai as the tool description shown to the
# model, so it stays short; implementation notes live in comments.
@agent.tool(retries=10)
@logfire.instrument('get_patient_by_name')
async def get_patient_by_name(ctx: RunContext[aiosqlite.Connection], name: str) -> tuple:
    """Search for a patient record using their full name."""
    # `name` is produced by the LLM, so it is bound as a query parameter rather
    # than formatted into the SQL text. That closes the injection hole and also
    # sidesteps SQLite's double-quote quirk: "..." is an identifier, and is only
    # treated as a string when no column of that name exists. For example,
    # full_name="full_name" used to match every row.
    async with ctx.deps.cursor() as cursor:
        await cursor.execute('SELECT * FROM patients WHERE full_name = ?', (name,))
        row = await cursor.fetchone()
    if row is None:
        print(Fore.YELLOW, f"Patient with name {name} not found.")
        # ModelRetry sends this message back to the model so it can try another
        # spelling (up to retries=10 times for this tool).
        raise ModelRetry(f"Patient with name {name} not found. Try a different spelling or format.")
    # First matching row from SELECT *.
    return row

# Agent tool: exact-match lookup of a patient by medical record number.
# Decorator order matters. pydantic-ai's @agent.tool registers the function it
# receives and returns it unchanged, so @logfire.instrument must be applied
# first (listed below it) for the agent to call the traced wrapper.
# The docstring is the tool description shown to the model; keep it short.
@agent.tool(retries=3)
@logfire.instrument('get_patient_by_mrn')
async def get_patient_by_mrn(ctx: RunContext[aiosqlite.Connection], mrn: str) -> tuple:
    """Search for a patient record using their medical record number."""
    # `mrn` comes from the LLM: bind it as a parameter instead of formatting it
    # into the SQL. This prevents injection and avoids SQLite resolving a
    # double-quoted value as a column identifier.
    async with ctx.deps.cursor() as cursor:
        await cursor.execute('SELECT * FROM patients WHERE mrn = ?', (mrn,))
        row = await cursor.fetchone()
    if row is None:
        # A plain ValueError (not ModelRetry) is raised, so a miss is not fed back
        # to the model for another attempt; main() catches ValueError separately.
        # NOTE(review): confirm this non-retryable behaviour is intended.
        raise ValueError(f"Patient with MRN {mrn} not found.")
    # First matching row from SELECT *.
    return row

async def seed_db(conn):
    """Recreate the ``patients`` table on *conn* and fill it with three demo rows.

    Any existing ``patients`` table is dropped first, so repeated calls always
    leave the same contents behind. The changes are committed before returning.
    """
    script = (
        "DROP TABLE IF EXISTS patients",
        "CREATE TABLE IF NOT EXISTS patients (patient_id INTEGER PRIMARY KEY, mrn TEXT, full_name TEXT, diagnosis TEXT)",
        "INSERT INTO patients (mrn, full_name, diagnosis) VALUES ('MRN001', 'John Smith', 'Hypertension')",
        "INSERT INTO patients (mrn, full_name, diagnosis) VALUES ('MRN002', 'Jane Smith', 'Type 2 Diabetes')",
        "INSERT INTO patients (mrn, full_name, diagnosis) VALUES ('MRN003', 'Jim Smith', 'Asthma')",
    )
    # Statements must run in order: drop, create, then the inserts.
    for statement in script:
        await conn.execute(statement)
    await conn.commit()

async def main():
    """Seed the demo database, then ask the agent for a deliberately misspelled patient.

    The agent run happens inside a Logfire span. On success the structured
    result is attached to the span and printed in green. On failure the
    exception is printed (yellow for ValueError, red otherwise) and recorded
    on the span.
    """
    async with aiosqlite.connect("clinic.sqlite") as connection:
        await seed_db(connection)
        with logfire.span('Retrieving patient record') as span:
            try:
                # "Jimmy Smith" matches no stored name exactly, so an exact-name
                # lookup fails and the agent's retry path gets exercised.
                result = await agent.run('Find the patient record for Jimmy Smith', deps=connection)
                span.set_attribute('result', result.output)
                print(Fore.GREEN, f"Patient record: {result.output}")
            except Exception as exc:
                # Same routing as separate `except ValueError` / `except Exception`
                # clauses: ValueError (and subclasses) in yellow, everything else in red.
                colour = Fore.YELLOW if isinstance(exc, ValueError) else Fore.RED
                print(colour, exc)
                span.record_exception(exc)

if __name__ == "__main__":
    asyncio.run(main())