Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -555,6 +556,9 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None)
if conn:
SecretCache.save_connection_uri(conn_id, conn.get_uri(), team_name=team_name)
return conn
except AirflowSecretsBackendAccessDenied:
# Authoritative deny — must NOT fall through to a less-restrictive backend.
raise
except Exception:
log.debug(
"Unable to retrieve connection from secrets backend (%s). "
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -498,6 +499,9 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | N
var_val = secrets_backend.get_variable(key=key, team_name=team_name)
if var_val is not None:
break
except AirflowSecretsBackendAccessDenied:
# Authoritative deny — must NOT fall through to a less-restrictive backend.
raise
except Exception:
log.exception(
"Unable to retrieve variable from secrets backend (%s). "
Expand Down
28 changes: 28 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,21 @@ def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
if e.response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
# Surface authz failures as a distinct ErrorType so the
# ExecutionAPISecretsBackend can refuse to fall back to a
# less-restrictive backend (e.g. env vars). 401/403 must
# not be conflated with "not found".
log.debug(
"Connection access denied",
conn_id=conn_id,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(
error=ErrorType.PERMISSION_DENIED,
detail={"conn_id": conn_id, "status_code": e.response.status_code},
)
raise
return ConnectionResponse.model_validate_json(resp.read())

Expand All @@ -483,6 +498,19 @@ def get(self, key: str) -> VariableResponse | ErrorResponse:
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
if e.response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN):
# See ConnectionOperations.get() above for rationale —
# authz failures must not be conflated with "not found".
log.debug(
"Variable access denied",
key=key,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(
error=ErrorType.PERMISSION_DENIED,
detail={"key": key, "status_code": e.response.status_code},
)
raise
return VariableResponse.model_validate_json(resp.read())

Expand Down
16 changes: 16 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ class AirflowNotFoundException(AirflowException):
status_code = HTTPStatus.NOT_FOUND


class AirflowSecretsBackendAccessDenied(PermissionError):
"""
Authoritative deny from a secrets backend; dispatcher must NOT fall through.

Distinct from a generic ``PermissionError`` (e.g. an incidental filesystem
``OSError``-family raise from inside an unrelated backend) so the
secrets-backend dispatcher loops can re-raise only this signal and keep
treating other exceptions as "try the next backend".
"""


class AirflowDagCycleException(AirflowException):
"""Raise when there is a cycle in Dag definition."""

Expand All @@ -83,6 +94,11 @@ class ErrorType(enum.Enum):
TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND"
ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND"
DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS"
# Distinct from API_SERVER_ERROR: signals an explicit 401/403 from the
# Execution API. Callers like ExecutionAPISecretsBackend treat this as
# a deny rather than a "not found" so the secrets-backend dispatcher
# does NOT fall through to a less-restrictive backend (e.g. env vars).
PERMISSION_DENIED = "PERMISSION_DENIED"
GENERIC_ERROR = "GENERIC_ERROR"
API_SERVER_ERROR = "API_SERVER_ERROR"

Expand Down
16 changes: 15 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
AssetUriRef,
BaseAssetUniqueKey,
)
from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType
from airflow.sdk.exceptions import (
AirflowNotFoundException,
AirflowRuntimeError,
AirflowSecretsBackendAccessDenied,
ErrorType,
)
from airflow.sdk.log import mask_secret

if TYPE_CHECKING:
Expand Down Expand Up @@ -168,6 +173,9 @@ def _get_connection(conn_id: str) -> Connection:
SecretCache.save_connection_uri(conn_id, conn.get_uri())
_mask_connection_secrets(conn)
return conn
except AirflowSecretsBackendAccessDenied:
# Authoritative deny — must NOT fall through to a less-restrictive backend.
raise
except Exception:
log.debug(
"Unable to retrieve connection from secrets backend (%s). "
Expand Down Expand Up @@ -215,6 +223,9 @@ async def _async_get_connection(conn_id: str) -> Connection:
SecretCache.save_connection_uri(conn_id, conn.get_uri())
_mask_connection_secrets(conn)
return conn
except AirflowSecretsBackendAccessDenied:
# Authoritative deny — must NOT fall through to a less-restrictive backend.
raise
except Exception:
# If one backend fails, try the next one
log.debug(
Expand Down Expand Up @@ -262,6 +273,9 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
if isinstance(var_val, str):
mask_secret(var_val, key)
return var_val
except AirflowSecretsBackendAccessDenied:
# Authoritative deny — must NOT fall through to a less-restrictive backend.
raise
except Exception:
log.exception(
"Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
Expand Down
71 changes: 59 additions & 12 deletions task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TYPE_CHECKING

from airflow.sdk.bases.secrets_backend import BaseSecretsBackend
from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied

if TYPE_CHECKING:
from airflow.sdk import Connection
Expand All @@ -43,6 +44,27 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No
"""
raise NotImplementedError("Use get_connection instead")

def _raise_if_authz_denied(self, msg, *, resource: str, key: str) -> None:
Comment thread
potiuk marked this conversation as resolved.
"""
Raise on an explicit deny response from the Execution API.

Returning None on a 401/403 would let the secrets-backend dispatcher
fall through to a less-restrictive backend (e.g. EnvironmentVariablesBackend
which performs no authorization checks). The Execution API explicitly
denied this request — we must not silently route around that decision.
Other ErrorResponse types (NOT_FOUND, transient API_SERVER_ERROR,
GENERIC_ERROR) keep the existing fallthrough behaviour so the
not-found-here path remains usable.
"""
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

if isinstance(msg, ErrorResponse) and msg.error == ErrorType.PERMISSION_DENIED:
Comment thread
potiuk marked this conversation as resolved.
raise AirflowSecretsBackendAccessDenied(
f"Access denied for {resource} {key!r} by Execution API; refusing to fall back "
"to a less-restrictive secrets backend."
)

def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None: # type: ignore[override]
"""
Return connection object by routing through SUPERVISOR_COMMS.
Expand All @@ -51,6 +73,9 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti
:param team_name: Name of the team associated to the task trying to access the connection.
Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
:return: Connection object or None if not found
:raises AirflowSecretsBackendAccessDenied: when the Execution API explicitly denies access
(401/403). Subclasses ``PermissionError``. The secrets-backend dispatcher must not fall
through to an unauthenticated backend in that case.
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.context import _process_connection_result_conn
Expand All @@ -59,15 +84,20 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti
try:
msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))

self._raise_if_authz_denied(msg, resource="connection", key=conn_id)

if isinstance(msg, ErrorResponse):
# Connection not found or error occurred
# Connection not found or transient error — allow fallback.
return None

# Convert ExecutionAPI response to SDK Connection
return _process_connection_result_conn(msg)
except AirflowSecretsBackendAccessDenied:
# Re-raise so the dispatcher does NOT fall through.
raise
except Exception:
# If SUPERVISOR_COMMS fails for any reason, return None
# to allow fallback to other backends
# If SUPERVISOR_COMMS fails for any non-authz reason, return None
# to allow fallback to other backends.
return None

def get_variable(self, key: str, team_name: str | None = None) -> str | None:
Expand All @@ -78,24 +108,31 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
:param team_name: Name of the team associated to the task trying to access the variable.
Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
:return: Variable value or None if not found
:raises AirflowSecretsBackendAccessDenied: when the Execution API explicitly denies access
(401/403). Subclasses ``PermissionError``. The secrets-backend dispatcher must not fall
through to an unauthenticated backend in that case.
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

try:
msg = SUPERVISOR_COMMS.send(GetVariable(key=key))

self._raise_if_authz_denied(msg, resource="variable", key=key)

if isinstance(msg, ErrorResponse):
# Variable not found or error occurred
# Variable not found or transient error — allow fallback.
return None

# Extract value from VariableResult
if isinstance(msg, VariableResult):
return msg.value # Already a string | None
return None
except AirflowSecretsBackendAccessDenied:
raise
except Exception:
# If SUPERVISOR_COMMS fails for any reason, return None
# to allow fallback to other backends
# If SUPERVISOR_COMMS fails for any non-authz reason, return None
# to allow fallback to other backends.
return None

async def aget_connection(self, conn_id: str) -> Connection | None: # type: ignore[override]
Expand All @@ -104,6 +141,7 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign

:param conn_id: connection id
:return: Connection object or None if not found
:raises AirflowSecretsBackendAccessDenied: see :meth:`get_connection`.
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.context import _process_connection_result_conn
Expand All @@ -112,15 +150,19 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign
try:
msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))

self._raise_if_authz_denied(msg, resource="connection", key=conn_id)

if isinstance(msg, ErrorResponse):
# Connection not found or error occurred
# Connection not found or transient error — allow fallback.
return None

# Convert ExecutionAPI response to SDK Connection
return _process_connection_result_conn(msg)
except AirflowSecretsBackendAccessDenied:
raise
except Exception:
# If SUPERVISOR_COMMS fails for any reason, return None
# to allow fallback to other backends
# If SUPERVISOR_COMMS fails for any non-authz reason, return None
# to allow fallback to other backends.
return None

async def aget_variable(self, key: str) -> str | None:
Expand All @@ -129,22 +171,27 @@ async def aget_variable(self, key: str) -> str | None:

:param key: Variable key
:return: Variable value or None if not found
:raises AirflowSecretsBackendAccessDenied: see :meth:`get_variable`.
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

try:
msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key))

self._raise_if_authz_denied(msg, resource="variable", key=key)

if isinstance(msg, ErrorResponse):
# Variable not found or error occurred
# Variable not found or transient error — allow fallback.
return None

# Extract value from VariableResult
if isinstance(msg, VariableResult):
return msg.value # Already a string | None
return None
except AirflowSecretsBackendAccessDenied:
raise
except Exception:
# If SUPERVISOR_COMMS fails for any reason, return None
# to allow fallback to other backends
# If SUPERVISOR_COMMS fails for any non-authz reason, return None
# to allow fallback to other backends.
return None
30 changes: 30 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,23 @@ def handle_request(request: httpx.Request) -> httpx.Response:
key="test_key",
)

@pytest.mark.parametrize("status_code", [401, 403])
def test_variable_get_authz_returns_permission_denied(self, status_code):
"""401/403 from the API server is reported as PERMISSION_DENIED, not raised.

The ExecutionAPISecretsBackend uses this distinction to refuse fallback
to a less-restrictive backend on an explicit deny.
"""

def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code=status_code, json={"detail": "Forbidden"})

client = make_client(transport=httpx.MockTransport(handle_request))
resp = client.variables.get(key="denied_var")
assert isinstance(resp, ErrorResponse)
assert resp.error == ErrorType.PERMISSION_DENIED
assert resp.detail == {"key": "denied_var", "status_code": status_code}

def test_variable_set_success(self):
# Simulate a successful response from the server when putting a variable
def handle_request(request: httpx.Request) -> httpx.Response:
Expand Down Expand Up @@ -1121,6 +1138,19 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert isinstance(result, ErrorResponse)
assert result.error == ErrorType.CONNECTION_NOT_FOUND

@pytest.mark.parametrize("status_code", [401, 403])
def test_connection_get_authz_returns_permission_denied(self, status_code):
"""401/403 from the API server is reported as PERMISSION_DENIED, not raised."""

def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code=status_code, json={"detail": "Forbidden"})

client = make_client(transport=httpx.MockTransport(handle_request))
result = client.connections.get(conn_id="denied_conn")
assert isinstance(result, ErrorResponse)
assert result.error == ErrorType.PERMISSION_DENIED
assert result.detail == {"conn_id": "denied_conn", "status_code": status_code}


class TestAssetEventOperations:
@pytest.mark.parametrize(
Expand Down
Loading
Loading