In [52]:
!pip install rich

Defaulting to user installation because normal site-packages is not writeable
Collecting rich
  Downloading rich-13.9.4-py3-none-any.whl.metadata (18 kB)
Collecting markdown-it-py>=2.2.0 (from rich)
  Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich)
  Downloading mdurl-0.1.2-py3-none-any.whl.metadata (1.6 kB)
Downloading rich-13.9.4-py3-none-any.whl (242 kB)
Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Installing collected packages: mdurl, markdown-it-py, rich
Successfully installed markdown-it-py-3.0.0 mdurl-0.1.2 rich-13.9.4


In [1]:
from pydantic import BaseModel, Field
from dataclasses import dataclass
import datetime

from pydantic_ai import Agent, RunContext, ModelRetry
from typing import Literal

from pydantic_ai.usage import Usage, UsageLimits
from pydantic_ai.messages import ModelMessage

from rich.prompt import Prompt
import nest_asyncio

In [2]:
# Patch asyncio so an event loop can be re-entered; needed for synchronous helpers
# such as Agent.run_sync when called inside Jupyter's already-running loop.
nest_asyncio.apply()

In [3]:
from dotenv import load_dotenv
import os

# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Never print the key itself: cell outputs are saved inside the notebook and
# travel with it, so printing the value leaks the credential. Only report
# whether it is available.
api_key = os.getenv("GEMINI_API_KEY")
print("GEMINI_API_KEY is set" if api_key else "GEMINI_API_KEY is NOT set - add it to your environment or .env file")

[output redacted: this cell printed the raw GEMINI_API_KEY - revoke and rotate that key]


In [4]:
# pydantic-ai model identifier in "<provider>:<model-name>" form; shared by all
# agents below (Gemini via the Generative Language API).
model = "google-gla:gemini-1.5-flash"

In [5]:
class FlightDetails(BaseModel):
    """ Details of the most suitable flight."""
    # Field descriptions end up in the JSON schema that pydantic-ai sends to the
    # model (both for the search result and for extraction), so their wording is
    # effectively part of the prompt and must be accurate.
    flight_number: str
    price: int  # whole-dollar price as listed on the page
    origin: str = Field(description="Three letter airport code")
    destination: str = Field(description="Three letter airport code")
    date: datetime.date


In [6]:
class NoFlightFound(BaseModel):
    """ When no valid flight is found. """
    # Field-less marker result: search_agent returns this instead of
    # FlightDetails when nothing matches. NOTE(review): the docstring above is
    # presumably used as this result type's description for the model - change
    # it deliberately, not as a cosmetic edit.

In [7]:
@dataclass()
class Deps:
    """Request data injected into the search agent's tools and result validator.

    Attributes:
        web_page_text: raw page text the extraction agent parses for flights.
        req_origin: requested origin airport code (e.g. "SFO").
        req_destination: requested destination airport code (e.g. "ANC").
        req_date: requested travel date.
    """

    web_page_text: str
    req_origin: str
    req_destination: str
    req_date: datetime.date

In [24]:
# Top-level agent: receives Deps and ends with either the chosen FlightDetails or
# NoFlightFound. retries=4 bounds how many times it may be asked to try again
# (e.g. when validate_result raises ModelRetry).
search_agent = Agent[Deps, FlightDetails | NoFlightFound](
    model,
    system_prompt="Your job is to find the cheapest flight for the user on the given date. ",
    retries=4,
    result_type=FlightDetails | NoFlightFound,  # type: ignore
)

In [25]:
# Helper agent used by the extract_flights tool: turns raw page text into a
# list of FlightDetails records.
extraction_agent = Agent(
    model,
    system_prompt="Extract all the flight details from the given text.",
    result_type=list[FlightDetails],
)

In [26]:
@search_agent.tool
async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
    """ Get details of all flights. """
    # The docstring above doubles as the tool description shown to the search
    # agent's model, so it is intentionally left short and unchanged.
    page_text = ctx.deps.web_page_text
    # Passing ctx.usage makes the nested extraction run count toward the same
    # Usage object as the outer search run.
    extraction = await extraction_agent.run(page_text, usage=ctx.usage)
    return extraction.data

In [27]:
@search_agent.result_validator
async def validate_result(ctx: RunContext[Deps], result: FlightDetails | NoFlightFound) -> FlightDetails | NoFlightFound:
    """Procedurally check that a proposed flight matches the requested route and date.

    A ``NoFlightFound`` result is accepted unchanged. For a ``FlightDetails``
    result, origin, destination and date are compared with the values in
    ``ctx.deps``. All mismatches are collected and reported together so the
    model sees every problem in a single retry.

    Returns:
        The result unchanged when it satisfies all constraints.

    Raises:
        ModelRetry: with one mismatch description per line if any constraint fails.
    """
    if isinstance(result, NoFlightFound):
        return result

    deps = ctx.deps
    errors: list[str] = []
    if result.origin != deps.req_origin:
        errors.append(
            f"Flight should have origin {deps.req_origin}, not {result.origin}"
        )
    if result.destination != deps.req_destination:
        errors.append(
            f"Flight should have destination {deps.req_destination}, not {result.destination}"
        )
    if result.date != deps.req_date:
        errors.append(
            f"Flight should be on {deps.req_date}, not {result.date}"
        )

    if errors:
        # One problem per line; a plain newline separator keeps each message clean.
        raise ModelRetry("\n".join(errors))
    return result

In [28]:
class SeatPreference(BaseModel):
    # Validated seat choice; the structured success result of seat_preference_agent.
    row: int = Field(ge=1, le=30)  # pydantic rejects rows outside 1..30 inclusive
    seat: Literal["A", "B", "C", "D", "E", "F"]  # seat letter; the agent prompt calls A and F window seats

In [29]:
class Failed(BaseModel):
    """ Unable to extract a seat selection."""
    # Field-less marker result: seat_preference_agent returns this when the user's
    # reply cannot be turned into a SeatPreference; find_seat then re-prompts.
    # NOTE(review): the docstring is presumably shown to the model as this
    # result's description - edit deliberately.

In [30]:
# Parses a free-text seat request into SeatPreference, or returns Failed when it
# cannot. The system prompt describes the seat map so requests like "window" or
# "extra leg room" can be resolved to a concrete row/seat.
seat_preference_agent = Agent[
    None, SeatPreference | Failed
](
    model,
    result_type=SeatPreference | Failed,
    system_prompt=(
        "Extract the user's seat preference. "
        "Seats A and F are window seats. "
        "Row 1 is the front row and has extra leg room. "
        "Rows 14 and 20 also have extra leg room. "
    ),
)

In [31]:
# Sample "scraped" page text; passed to the search agent as Deps.web_page_text and
# parsed by extraction_agent. For the request built below (SFO -> ANC on
# 2025-01-10) two entries match: #1 ($350) and #4 ($250).
# NOTE(review): entry #4's flight number "NYC-LA101" does not match its listed
# SFO -> ANC route - confirm whether that is intentional test data.
flights_web_page = """
1. Flight SFO-AK123
- Price: $350
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

2. Flight SFO-AK456
- Price: $370
- Origin: San Francisco International Airport (SFO)
- Destination: Fairbanks International Airport (FAI)
- Date: January 10, 2025

3. Flight SFO-AK789
- Price: $400
- Origin: San Francisco International Airport (SFO)
- Destination: Juneau International Airport (JNU)
- Date: January 20, 2025

4. Flight NYC-LA101
- Price: $250
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

5. Flight CHI-MIA202
- Price: $200
- Origin: Chicago O'Hare International Airport (ORD)
- Destination: Miami International Airport (MIA)
- Date: January 12, 2025

6. Flight BOS-SEA303
- Price: $120
- Origin: Boston Logan International Airport (BOS)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 12, 2025

7. Flight DFW-DEN404
- Price: $150
- Origin: Dallas/Fort Worth International Airport (DFW)
- Destination: Denver International Airport (DEN)
- Date: January 10, 2025

8. Flight ATL-HOU505
- Price: $180
- Origin: Hartsfield-Jackson Atlanta International Airport (ATL)
- Destination: George Bush Intercontinental Airport (IAH)
- Date: January 10, 2025
"""

In [32]:
# Upper bound on model requests; usage_limits is passed to every agent run below.
MAX_MODEL_REQUESTS = 15
usage_limits = UsageLimits(request_limit=MAX_MODEL_REQUESTS)

In [33]:
# Search request: SFO -> ANC on 2025-01-10, evaluated against flights_web_page.
deps = Deps(
    req_origin="SFO",
    req_destination="ANC",
    req_date=datetime.date(2025, 1, 10),
    web_page_text=flights_web_page,
)
deps.req_destination

'ANC'

In [34]:
# Conversation history handed to the search agent; None = no prior messages.
# Replaced with result.all_messages(...) when the user asks to keep searching.
message_history: list[ModelMessage] | None = None

In [35]:
# Shared request/token counter passed to the agent runs alongside usage_limits,
# so (presumably) the limit applies to the combined total of those runs.
usage: Usage = Usage()

In [36]:
async def find_seat(usage: Usage) -> SeatPreference:
    """Ask the user for a seat until seat_preference_agent yields a SeatPreference.

    After an unrecognised answer, the whole conversation so far is passed back as
    message_history on the next attempt. Every run is recorded in ``usage`` and
    checked against the notebook-level ``usage_limits``.
    """
    history: list[ModelMessage] | None = None
    while True:
        reply = Prompt.ask("What seat would you like?")
        run = await seat_preference_agent.run(
            reply,
            message_history=history,
            usage=usage,
            usage_limits=usage_limits,
        )
        if not isinstance(run.data, SeatPreference):
            print("Could not understand seat preference. Please try again.")
            history = run.all_messages()
            continue
        return run.data

In [37]:
async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
    """Stand-in for the purchase step: only prints what would be bought."""
    message = f"Purchasing flight {flight_details=!r} {seat=!r}..."
    print(message)

In [40]:
result = search_agent.run_sync(
    f"Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}",
    deps=deps,
    usage=usage,
    message_history=message_history,
    usage_limits=usage_limits,
)

if isinstance(result.data, NoFlightFound):
    print("No flight found")
else:
    flight = result.data
    print(f"Flight found: {flight}")
    answer = Prompt.ask(
        "Do you want to buy this flight, or keep searching? (buy/*search)",
        choices=["buy", "search", ""],
        show_choices=False,
    )

    if answer == "buy":
        seat = await find_seat(usage)
        await buy_tickets(flight, seat)
    else:
        message_history = result.all_messages(
            result_tool_return_content="Please suggest another flight"
        )

Flight found: flight_number='NYC-LA101' price=250 origin='SFO' destination='ANC' date=datetime.date(2025, 1, 10)


 buy


 2


Could not understand seat preference. Please try again.


 1


Purchasing flight flight_details=FlightDetails(flight_number='NYC-LA101', price=250, origin='SFO', destination='ANC', date=datetime.date(2025, 1, 10)) seat=SeatPreference(row=1, seat='A')...


In [41]:
# Final result of the search cell above: a FlightDetails or NoFlightFound instance.
print(result.data)

flight_number='NYC-LA101' price=250 origin='SFO' destination='ANC' date=datetime.date(2025, 1, 10)
