Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def airflow_toolset_to_langchain_tools(
works regardless of how the agent handles tool errors. Raising instead would
abort the run under ``create_agent``'s default tool-error handling.

Feeding the message back is bounded by the tool's ``max_retries``: once a tool
has raised ``ModelRetry`` more than ``max_retries`` times without an intervening
successful call (for example on an unrecoverable connection error), the exception
propagates instead of being returned, so the run fails rather than looping
forever. The count resets after a successful call and after the exception has
propagated.

The toolset's ``get_tools`` is invoked eagerly here to enumerate the tools.

.. warning::
Expand Down Expand Up @@ -148,20 +153,39 @@ def _validate(kwargs: dict[str, Any]) -> dict[str, Any]:
# the args unchanged; a typed one coerces them (e.g. "5" -> 5).
return toolset_tool.args_validator.validate_python(kwargs)

# ModelRetry is a "feed this back to the model and retry" signal, so the bridge
# returns its message as the tool output instead of raising (see docstring).
# Bound it the way native pydantic-ai does, via the tool's max_retries: a tool
# that keeps raising ModelRetry (e.g. an unrecoverable connection error) must
# eventually propagate so the run fails rather than looping forever. The count
# resets on the first successful call.
max_retries = toolset_tool.max_retries if toolset_tool.max_retries is not None else 1
retries = {"count": 0}

def _handle_retry(error: ModelRetry) -> str:
    """
    Account for one more ``ModelRetry`` and decide whether to feed it back.

    While the running count stays within ``max_retries`` the error text is
    returned so it becomes the tool output. Once the count would exceed the
    budget, the counter is cleared and the original error is re-raised.
    """
    attempts = retries["count"] + 1
    if attempts <= max_retries:
        retries["count"] = attempts
        return str(error)
    # Budget exhausted: clear the counter first so a reused tool begins its next
    # run with a fresh budget rather than staying permanently exhausted.
    retries["count"] = 0
    raise error

def _sync_call(**kwargs: Any) -> Any:
try:
return _run_coro_sync(toolset.call_tool(name, _validate(kwargs), ctx, toolset_tool))
result = _run_coro_sync(toolset.call_tool(name, _validate(kwargs), ctx, toolset_tool))
except ModelRetry as e:
# ModelRetry is a "feed this back to the model and retry" signal, not a
# failure. Return the message as the tool output so the model self-corrects
# (see docstring); raising would abort under create_agent's default handling.
return str(e)
return _handle_retry(e)
retries["count"] = 0
return result

async def _async_call(**kwargs: Any) -> Any:
try:
return await toolset.call_tool(name, _validate(kwargs), ctx, toolset_tool)
result = await toolset.call_tool(name, _validate(kwargs), ctx, toolset_tool)
except ModelRetry as e:
return str(e)
return _handle_retry(e)
retries["count"] = 0
return result

return structured_tool_cls.from_function(
func=_sync_call,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from __future__ import annotations

import json
import sqlite3
from contextlib import suppress
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -76,31 +75,6 @@
"required": ["sql"],
}

_POSTGRES_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
with suppress(ImportError):
import psycopg2.errors as _psycopg2_errors

_POSTGRES_RETRYABLE_EXCEPTIONS += (
_psycopg2_errors.UndefinedColumn,
_psycopg2_errors.UndefinedTable,
)

with suppress(ImportError):
from psycopg import errors as _psycopg3_errors

_POSTGRES_RETRYABLE_EXCEPTIONS += (
_psycopg3_errors.UndefinedColumn,
_psycopg3_errors.UndefinedTable,
)

_SQLALCHEMY_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
with suppress(ImportError):
from sqlalchemy.exc import (
ProgrammingError as _SQLAlchemyProgrammingError,
)

_SQLALCHEMY_RETRYABLE_EXCEPTIONS = (_SQLAlchemyProgrammingError,)


class SQLToolset(AbstractToolset[Any]):
"""
Expand All @@ -112,6 +86,13 @@ class SQLToolset(AbstractToolset[Any]):
Uses a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` resolved
lazily from the given ``db_conn_id``.

When a tool fails, the database's error message is returned to the agent as a
retry (:class:`pydantic_ai.ModelRetry`) so the model can correct its SQL within
the run instead of failing the task. ``pydantic-ai`` bounds this by the tool's
``max_retries``, so an unrecoverable error -- a bad connection or an auth
failure -- exhausts the retries and fails the task for Airflow to retry. The
toolset does not inspect the error type or message.

:param db_conn_id: Airflow connection ID for the database.
:param allowed_tables: Restrict which tables the agent can discover via
``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables
Expand Down Expand Up @@ -243,15 +224,27 @@ async def call_tool(
ctx: RunContext[Any],
tool: ToolsetTool[Any],
) -> Any:
if name == "list_tables":
return self._list_tables()
if name == "get_schema":
return self._get_schema(tool_args["table_name"])
if name == "query":
return self._query(tool_args["sql"])
if name == "check_query":
if name not in ("list_tables", "get_schema", "query", "check_query"):
raise ValueError(f"Unknown tool: {name!r}")
try:
if name == "list_tables":
return self._list_tables()
if name == "get_schema":
return self._get_schema(tool_args["table_name"])
if name == "query":
return self._query(tool_args["sql"])
return self._check_query(tool_args["sql"])
raise ValueError(f"Unknown tool: {name!r}")
except Exception as e:
# Hand the database's own error back to the agent as a retry so it can
# read the message and fix its SQL within the run. pydantic-ai bounds
# this by the tool's max_retries, so an unrecoverable error (a bad
# connection, an auth failure) exhausts the budget and fails the task
# for Airflow to retry, rather than being silently worked around.
raise ModelRetry(
f"The {name} tool failed: {e}\n"
"Use the list_tables and get_schema tools to inspect the database, "
"then fix the query and try again."
) from e

# ------------------------------------------------------------------
# Tool implementations
Expand Down Expand Up @@ -317,14 +310,7 @@ def _query(self, sql: str) -> str:
allow_read_only_metadata=True,
)

try:
rows = hook.get_records(sql)
except Exception as e:
if self._is_retryable_query_error(hook, e):
raise ModelRetry(
f"error: {e!s}, Use get_schema and list_tables tools for more details."
) from e
raise
rows = hook.get_records(sql)
# Fetch column names from cursor description.
col_names: list[str] | None = None
if hook.last_description:
Expand All @@ -343,24 +329,6 @@ def _query(self, sql: str) -> str:
output["max_rows"] = self._max_rows
return json.dumps(output, default=str)

@staticmethod
def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
    """
    Tell whether a query failure is one of the recognised retryable errors.

    The decision is keyed on the hook's ``conn_type`` attribute (``None`` when the
    hook has none). For the driver-specific checks the exception's ``orig``
    attribute is inspected when present, otherwise the exception itself.

    :param hook: Hook the query ran on; only ``conn_type`` is read.
    :param error: Exception raised while running the query.
    :return: ``True`` when the error matches a retryable type or message.
    """
    # NOTE(review): ``orig`` presumably carries the driver error wrapped by
    # SQLAlchemy -- confirm against callers.
    underlying = getattr(error, "orig", error)
    conn_type = getattr(hook, "conn_type", None)

    if conn_type == "postgres":
        # With no Postgres driver importable the tuple is empty and nothing matches.
        if not _POSTGRES_RETRYABLE_EXCEPTIONS:
            return False
        return isinstance(underlying, _POSTGRES_RETRYABLE_EXCEPTIONS)

    if conn_type == "sqlite":
        if not isinstance(underlying, sqlite3.OperationalError):
            return False
        # Only the missing-column / missing-table messages count as retryable.
        text = str(underlying).lower()
        return any(marker in text for marker in ("no such column", "no such table"))

    # Any other connection type: match on the exception as raised, not on ``orig``.
    # TODO: Add support for other databases.
    return bool(_SQLALCHEMY_RETRYABLE_EXCEPTIONS) and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS)

def _check_query(self, sql: str) -> str:
# Resolve the dialect best-effort: if the connection can't be reached we
# still syntax-check dialect-agnostically rather than reporting invalid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,40 @@ def test_model_retry_returned_as_tool_output_sync(self):

assert boom.invoke({}) == "fix your input and try again"

def test_repeated_model_retry_propagates_then_resets(self):
    # Guard against an endless feedback loop: a tool that raises ModelRetry on every
    # call gets its message fed back only while the tool's max_retries (1 here)
    # allows it. The next failure must propagate so the run fails, and the budget
    # must then start over so a reused tool is not poisoned for the next run.
    tools_by_name = {tool.name: tool for tool in airflow_toolset_to_langchain_tools(FakeToolset())}
    boom = tools_by_name["boom"]
    fed_back = "fix your input and try again"

    # First failure is within budget and comes back as the tool output.
    assert boom.invoke({}) == fed_back
    # Second consecutive failure exhausts the budget and is raised.
    with pytest.raises(ModelRetry, match="fix your input"):
        boom.invoke({})
    # The budget was reset on propagation, so the message is fed back again.
    assert boom.invoke({}) == fed_back

def test_model_retry_count_resets_after_success(self):
    # A successful call between two ModelRetry failures restores the budget, so the
    # second failure is fed back to the model instead of propagating immediately.
    class FlakyToolset(FakeToolset):
        def __init__(self) -> None:
            super().__init__()
            # Scripted results for successive "boom" calls: fail, succeed, fail.
            self._outcomes = iter([ModelRetry("retry 1"), "ok", ModelRetry("retry 2")])

        async def call_tool(self, name, tool_args, ctx, tool) -> Any:
            # Every tool other than "boom" keeps the parent's behaviour.
            if name != "boom":
                return await super().call_tool(name, tool_args, ctx, tool)
            scripted = next(self._outcomes)
            if isinstance(scripted, ModelRetry):
                raise scripted
            return scripted

    tools_by_name = {tool.name: tool for tool in airflow_toolset_to_langchain_tools(FlakyToolset())}
    boom = tools_by_name["boom"]

    assert boom.invoke({}) == "retry 1"  # first failure: within budget, fed back
    assert boom.invoke({}) == "ok"  # success clears the retry count
    assert boom.invoke({}) == "retry 2"  # counted from zero again, so fed back (not raised)

def test_model_retry_returned_as_tool_output_async(self):
boom = {t.name: t for t in airflow_toolset_to_langchain_tools(FakeToolset())}["boom"]

Expand Down
Loading