Merged
70 changes: 69 additions & 1 deletion providers/common/ai/docs/toolsets.rst
@@ -24,7 +24,7 @@ Airflow's 350+ provider hooks already have typed methods, rich docstrings,
and managed credentials. Toolsets expose them as pydantic-ai tools so that
LLM agents can call them during multi-turn reasoning.

Three toolsets are included:

- :class:`~airflow.providers.common.ai.toolsets.hook.HookToolset` — generic
  adapter for any Airflow Hook.
@@ -121,6 +121,69 @@ Parameters
  Default ``False`` — only SELECT-family statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.

``DataFusionToolset``
---------------------

Curated toolset wrapping
:class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`
with three tools — ``list_tables``, ``get_schema``, and ``query`` — for
querying Parquet, CSV, Avro, and Iceberg data on object stores (e.g. S3) or
the local filesystem via Apache DataFusion.

.. list-table::
   :header-rows: 1
   :widths: 20 50

   * - Tool
     - Description
   * - ``list_tables``
     - Lists registered table names
   * - ``get_schema``
     - Returns column names and types for a table (Arrow schema)
   * - ``query``
     - Executes a SQL query and returns rows as JSON

Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
registers a table backed by Parquet, CSV, Avro, or Iceberg data. Multiple
configs can be registered so that SQL queries can join across tables.

.. code-block:: python

    from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
    from airflow.providers.common.sql.config import DataSourceConfig

    toolset = DataFusionToolset(
        datasource_configs=[
            DataSourceConfig(
                conn_id="aws_default",
                table_name="sales",
                uri="s3://my-bucket/data/sales/",
                format="parquet",
            ),
            DataSourceConfig(
                conn_id="aws_default",
                table_name="returns",
                uri="s3://my-bucket/data/returns/",
                format="csv",
            ),
        ],
        max_rows=100,
    )

The ``DataFusionEngine`` is created lazily on the first tool call. This
toolset requires the ``datafusion`` extra of
``apache-airflow-providers-common-sql``.
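Once constructed, the toolset plugs into a pydantic-ai agent like any other
toolset. A minimal sketch, assuming the ``sales`` table from the example above
and a live environment with credentials configured — the model name and prompt
are illustrative, not part of this provider:

```python
from pydantic_ai import Agent

from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
from airflow.providers.common.sql.config import DataSourceConfig

toolset = DataFusionToolset(
    datasource_configs=[
        DataSourceConfig(
            conn_id="aws_default",
            table_name="sales",
            uri="s3://my-bucket/data/sales/",
            format="parquet",
        ),
    ],
)

# Illustrative model choice; any pydantic-ai-supported model works.
agent = Agent("openai:gpt-4o", toolsets=[toolset])
result = agent.run_sync("What is the total sales amount per region?")
print(result.output)
```

The agent can now call ``list_tables``, ``get_schema``, and ``query`` during
its reasoning loop; the engine is only created when the first tool call arrives.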

Parameters
^^^^^^^^^^

- ``datasource_configs``: One or more
  :class:`~airflow.providers.common.sql.config.DataSourceConfig` entries.
  Requires ``apache-airflow-providers-common-sql[datafusion]``.
- ``allow_writes``: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
  INSERT INTO, etc.). Default ``False`` — only SELECT-family statements are
  permitted. DataFusion on object stores is mostly read-only, but it does
  support DDL for in-memory tables; this guard blocks those by default.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
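The interplay of the two guards can be sketched in plain Python.
``reject_non_select`` and ``limit_rows`` below are simplified stand-ins for the
provider's ``validate_sql()`` and the ``query`` tool's row limiting, not the
actual implementation:

```python
import json
import re


def reject_non_select(sql: str) -> None:
    # Simplified stand-in for validate_sql(): permit only SELECT-family statements.
    first_token = re.split(r"\s+", sql.strip(), maxsplit=1)[0].upper()
    if first_token not in {"SELECT", "WITH", "EXPLAIN", "SHOW", "DESCRIBE"}:
        raise ValueError(f"Statement type {first_token!r} is not permitted")


def limit_rows(pydict: dict, max_rows: int = 50) -> str:
    # Column-oriented result -> row dicts, truncated to max_rows,
    # mirroring the shape the query tool returns.
    col_names = list(pydict.keys())
    num_rows = len(next(iter(pydict.values()), []))
    rows = [{col: pydict[col][i] for col in col_names} for i in range(min(num_rows, max_rows))]
    output = {"rows": rows, "count": num_rows}
    if num_rows > max_rows:
        output["truncated"] = True
        output["max_rows"] = max_rows
    return json.dumps(output, default=str)


reject_non_select("SELECT * FROM sales")  # passes silently
result = json.loads(limit_rows({"region": ["eu", "us", "apac"]}, max_rows=2))
# result carries count=3, truncated=True, and only the first two rows
```

With ``allow_writes=False`` a ``CREATE TABLE`` or ``INSERT INTO`` statement is
rejected before it ever reaches the engine, and oversized results always report
their true ``count`` so the agent knows rows were dropped.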

Security
--------
@@ -155,6 +218,11 @@ No single layer is sufficient — they work together.
       ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
     - Does not prevent the agent from reading sensitive data that the
       database user has SELECT access to.
   * - **DataFusionToolset: read-only by default**
     - ``allow_writes=False`` (default) validates every SQL query through
       ``validate_sql()`` and rejects CREATE TABLE, CREATE VIEW, INSERT
       INTO, and other non-SELECT statements.
     - Does not prevent the agent from reading any registered data source.
   * - **SQLToolset: allowed_tables**
     - Restricts which tables appear in ``list_tables`` and ``get_schema``
       responses, limiting the agent's knowledge of the schema.
@@ -0,0 +1,207 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store workflows."""

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any

try:
    from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError, validate_sql as _validate_sql
    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
    from airflow.providers.common.sql.datafusion.exceptions import QueryExecutionException
except ImportError as e:
    from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException

    raise AirflowOptionalProviderFeatureException(e)

from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_core import SchemaValidator, core_schema

if TYPE_CHECKING:
    from pydantic_ai._run_context import RunContext

    from airflow.providers.common.sql.config import DataSourceConfig

log = logging.getLogger(__name__)

_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())

# JSON Schemas for the three DataFusion tools.
_LIST_TABLES_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {},
}

_GET_SCHEMA_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "table_name": {"type": "string", "description": "Name of the table to inspect."},
    },
    "required": ["table_name"],
}

_QUERY_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "SQL query to execute."},
    },
    "required": ["sql"],
}


class DataFusionToolset(AbstractToolset[Any]):
    """
    Curated toolset that gives an LLM agent SQL access to object-storage data via Apache DataFusion.

    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
    backed by
    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.

    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
    local storage. Multiple configs can be registered so that SQL queries can
    join across tables.

    Requires the ``datafusion`` extra of ``apache-airflow-providers-common-sql``.

    :param datasource_configs: One or more DataFusion data-source configurations.
    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
        are permitted.
    :param max_rows: Maximum number of rows returned from the ``query`` tool.
        Default ``50``.
    """

    def __init__(
        self,
        datasource_configs: list[DataSourceConfig],
        *,
        allow_writes: bool = False,
        max_rows: int = 50,
    ) -> None:
        if not datasource_configs:
            raise ValueError("datasource_configs must contain at least one DataSourceConfig")
        self._datasource_configs = datasource_configs
        self._allow_writes = allow_writes
        self._max_rows = max_rows
        self._engine: DataFusionEngine | None = None

    @property
    def id(self) -> str:
        suffix = "_".join(config.table_name.replace("-", "_") for config in self._datasource_configs)
        return f"sql_datafusion_{suffix}"

    def _get_engine(self) -> DataFusionEngine:
        """Lazily create and configure a DataFusionEngine from *datasource_configs*."""
        if self._engine is None:
            engine = DataFusionEngine()
            for config in self._datasource_configs:
                engine.register_datasource(config)
            self._engine = engine
        return self._engine

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        tools: dict[str, ToolsetTool[Any]] = {}

        for name, description, schema in (
            ("list_tables", "List available table names.", _LIST_TABLES_SCHEMA),
            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
        ):
            tool_def = ToolDefinition(
                name=name,
                description=description,
                parameters_json_schema=schema,
                sequential=True,
            )
            tools[name] = ToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=1,
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        if name == "list_tables":
            return self._list_tables()
        if name == "get_schema":
            return self._get_schema(tool_args["table_name"])
        if name == "query":
            return self._query(tool_args["sql"])
        raise ValueError(f"Unknown tool: {name!r}")

    def _list_tables(self) -> str:
        try:
            engine = self._get_engine()
            tables: list[str] = list(engine.session_context.catalog().schema().table_names())
            return json.dumps(tables)
        except Exception as ex:
            log.warning("list_tables failed: %s", ex)
            return json.dumps({"error": str(ex)})

    def _get_schema(self, table_name: str) -> str:
        engine = self._get_engine()
        # session_context lookup is required here instead of engine.registered_tables,
        # because registered_tables only tracks tables registered via datasource config.
        # When allow_writes is enabled, the agent may create temporary in-memory tables
        # that would not be captured there.
        if not engine.session_context.table_exist(table_name):
            return json.dumps({"error": f"Table {table_name!r} is not available"})
        # Intentionally using session_context instead of engine.get_schema() —
        # the latter returns a pre-formatted string intended for other operators,
        # not a JSON-compatible format.
        # TODO: refactor engine.get_schema() to return JSON and update this accordingly
        table = engine.session_context.table(table_name)
        columns = [{"name": f.name, "type": str(f.type)} for f in table.schema()]
        return json.dumps(columns)

    def _query(self, sql: str) -> str:
        try:
            if not self._allow_writes:
                _validate_sql(sql)

            engine = self._get_engine()
            pydict = engine.execute_query(sql)
            col_names = list(pydict.keys())
            num_rows = len(next(iter(pydict.values()), []))

            result: list[dict[str, Any]] = [
                {col: pydict[col][i] for col in col_names} for i in range(min(num_rows, self._max_rows))
            ]

            truncated = num_rows > self._max_rows
            output: dict[str, Any] = {"rows": result, "count": num_rows}
            if truncated:
                output["truncated"] = True
                output["max_rows"] = self._max_rows
            return json.dumps(output, default=str)
        except SQLSafetyError as ex:
            log.warning("query failed SQL safety validation: %s", ex)
            raise
        except QueryExecutionException as ex:
            return json.dumps({"error": str(ex), "query": sql})