diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index e5480d5c51f21..55a61de022ec1 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -35,6 +35,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -555,6 +556,9 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) if conn: SecretCache.save_connection_uri(conn_id, conn.get_uri(), team_name=team_name) return conn + except AirflowSecretsBackendAccessDenied: + # Authoritative deny — must NOT fall through to a less-restrictive backend. + raise except Exception: log.debug( "Unable to retrieve connection from secrets backend (%s). " diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 9bc2ba9be9649..0d543f334be42 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -32,6 +32,7 @@ from airflow.configuration import conf, ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -498,6 +499,9 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | N var_val = secrets_backend.get_variable(key=key, team_name=team_name) if var_val is not None: break + except AirflowSecretsBackendAccessDenied: + # Authoritative deny — must NOT fall through to a less-restrictive backend. + raise except Exception: log.exception( "Unable to retrieve variable from secrets backend (%s). " diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 269978ac9dd1d..ac27d2bc30c05 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -460,6 +460,21 @@ def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse: status_code=e.response.status_code, ) return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id}) + if e.response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + # Surface authz failures as a distinct ErrorType so the + # ExecutionAPISecretsBackend can refuse to fall back to a + # less-restrictive backend (e.g. env vars). 401/403 must + # not be conflated with "not found". + log.debug( + "Connection access denied", + conn_id=conn_id, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"conn_id": conn_id, "status_code": e.response.status_code}, + ) raise return ConnectionResponse.model_validate_json(resp.read()) @@ -483,6 +498,19 @@ def get(self, key: str) -> VariableResponse | ErrorResponse: status_code=e.response.status_code, ) return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key}) + if e.response.status_code in (HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN): + # See ConnectionOperations.get() above for rationale — + # authz failures must not be conflated with "not found". + log.debug( + "Variable access denied", + key=key, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"key": key, "status_code": e.response.status_code}, + ) raise return VariableResponse.model_validate_json(resp.read()) diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index b0ff82be293e7..04bf0a8c56ac8 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -57,6 +57,17 @@ class AirflowNotFoundException(AirflowException): status_code = HTTPStatus.NOT_FOUND +class AirflowSecretsBackendAccessDenied(PermissionError): + """ + Authoritative deny from a secrets backend; dispatcher must NOT fall through. + + Distinct from a generic ``PermissionError`` (e.g. an incidental filesystem + ``OSError``-family raise from inside an unrelated backend) so the + secrets-backend dispatcher loops can re-raise only this signal and keep + treating other exceptions as "try the next backend". + """ + + class AirflowDagCycleException(AirflowException): """Raise when there is a cycle in Dag definition.""" @@ -83,6 +94,11 @@ class ErrorType(enum.Enum): TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND" ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND" DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS" + # Distinct from API_SERVER_ERROR: signals an explicit 401/403 from the + # Execution API. Callers like ExecutionAPISecretsBackend treat this as + # a deny rather than a "not found" so the secrets-backend dispatcher + # does NOT fall through to a less-restrictive backend (e.g. env vars). + PERMISSION_DENIED = "PERMISSION_DENIED" GENERIC_ERROR = "GENERIC_ERROR" API_SERVER_ERROR = "API_SERVER_ERROR" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index cdd258039898e..bc6ec0ead2b9f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -42,7 +42,12 @@ AssetUriRef, BaseAssetUniqueKey, ) -from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import ( + AirflowNotFoundException, + AirflowRuntimeError, + AirflowSecretsBackendAccessDenied, + ErrorType, +) from airflow.sdk.log import mask_secret if TYPE_CHECKING: @@ -168,6 +173,9 @@ def _get_connection(conn_id: str) -> Connection: SecretCache.save_connection_uri(conn_id, conn.get_uri()) _mask_connection_secrets(conn) return conn + except AirflowSecretsBackendAccessDenied: + # Authoritative deny — must NOT fall through to a less-restrictive backend. + raise except Exception: log.debug( "Unable to retrieve connection from secrets backend (%s). " @@ -215,6 +223,9 @@ async def _async_get_connection(conn_id: str) -> Connection: SecretCache.save_connection_uri(conn_id, conn.get_uri()) _mask_connection_secrets(conn) return conn + except AirflowSecretsBackendAccessDenied: + # Authoritative deny — must NOT fall through to a less-restrictive backend. + raise except Exception: # If one backend fails, try the next one log.debug( @@ -262,6 +273,9 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: if isinstance(var_val, str): mask_secret(var_val, key) return var_val + except AirflowSecretsBackendAccessDenied: + # Authoritative deny — must NOT fall through to a less-restrictive backend. + raise except Exception: log.exception( "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.", diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py index b8b240b498837..57bffd12a1620 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING from airflow.sdk.bases.secrets_backend import BaseSecretsBackend +from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied if TYPE_CHECKING: from airflow.sdk import Connection @@ -43,6 +44,27 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No """ raise NotImplementedError("Use get_connection instead") + def _raise_if_authz_denied(self, msg, *, resource: str, key: str) -> None: + """ + Raise on an explicit deny response from the Execution API. + + Returning None on a 401/403 would let the secrets-backend dispatcher + fall through to a less-restrictive backend (e.g. EnvironmentVariablesBackend + which performs no authorization checks). The Execution API explicitly + denied this request — we must not silently route around that decision. + Other ErrorResponse types (NOT_FOUND, transient API_SERVER_ERROR, + GENERIC_ERROR) keep the existing fallthrough behaviour so the + not-found-here path remains usable. + """ + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if isinstance(msg, ErrorResponse) and msg.error == ErrorType.PERMISSION_DENIED: + raise AirflowSecretsBackendAccessDenied( + f"Access denied for {resource} {key!r} by Execution API; refusing to fall back " + "to a less-restrictive secrets backend." + ) + def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None: # type: ignore[override] """ Return connection object by routing through SUPERVISOR_COMMS. @@ -51,6 +73,9 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti :param team_name: Name of the team associated to the task trying to access the connection. Unused here because the team name is inferred from the task ID provided in the execution API JWT token. :return: Connection object or None if not found + :raises AirflowSecretsBackendAccessDenied: when the Execution API explicitly denies access + (401/403). Subclasses ``PermissionError``. The secrets-backend dispatcher must not fall + through to an unauthenticated backend in that case. """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn @@ -59,15 +84,20 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti try: msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) + self._raise_if_authz_denied(msg, resource="connection", key=conn_id) + if isinstance(msg, ErrorResponse): - # Connection not found or error occurred + # Connection not found or transient error — allow fallback. return None # Convert ExecutionAPI response to SDK Connection return _process_connection_result_conn(msg) + except AirflowSecretsBackendAccessDenied: + # Re-raise so the dispatcher does NOT fall through. + raise except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None - # to allow fallback to other backends + # If SUPERVISOR_COMMS fails for any non-authz reason, return None + # to allow fallback to other backends. return None def get_variable(self, key: str, team_name: str | None = None) -> str | None: @@ -78,6 +108,9 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: :param team_name: Name of the team associated to the task trying to access the variable. Unused here because the team name is inferred from the task ID provided in the execution API JWT token. :return: Variable value or None if not found + :raises AirflowSecretsBackendAccessDenied: when the Execution API explicitly denies access + (401/403). Subclasses ``PermissionError``. The secrets-backend dispatcher must not fall + through to an unauthenticated backend in that case. """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -85,17 +118,21 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: try: msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) + self._raise_if_authz_denied(msg, resource="variable", key=key) + if isinstance(msg, ErrorResponse): - # Variable not found or error occurred + # Variable not found or transient error — allow fallback. return None # Extract value from VariableResult if isinstance(msg, VariableResult): return msg.value # Already a string | None return None + except AirflowSecretsBackendAccessDenied: + raise except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None - # to allow fallback to other backends + # If SUPERVISOR_COMMS fails for any non-authz reason, return None + # to allow fallback to other backends. return None async def aget_connection(self, conn_id: str) -> Connection | None: # type: ignore[override] @@ -104,6 +141,7 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign :param conn_id: connection id :return: Connection object or None if not found + :raises AirflowSecretsBackendAccessDenied: see :meth:`get_connection`. """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn @@ -112,15 +150,19 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign try: msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) + self._raise_if_authz_denied(msg, resource="connection", key=conn_id) + if isinstance(msg, ErrorResponse): - # Connection not found or error occurred + # Connection not found or transient error — allow fallback. return None # Convert ExecutionAPI response to SDK Connection return _process_connection_result_conn(msg) + except AirflowSecretsBackendAccessDenied: + raise except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None - # to allow fallback to other backends + # If SUPERVISOR_COMMS fails for any non-authz reason, return None + # to allow fallback to other backends. return None async def aget_variable(self, key: str) -> str | None: @@ -129,6 +171,7 @@ async def aget_variable(self, key: str) -> str | None: :param key: Variable key :return: Variable value or None if not found + :raises AirflowSecretsBackendAccessDenied: see :meth:`get_variable`. """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -136,15 +179,19 @@ async def aget_variable(self, key: str) -> str | None: try: msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key)) + self._raise_if_authz_denied(msg, resource="variable", key=key) + if isinstance(msg, ErrorResponse): - # Variable not found or error occurred + # Variable not found or transient error — allow fallback. return None # Extract value from VariableResult if isinstance(msg, VariableResult): return msg.value # Already a string | None return None + except AirflowSecretsBackendAccessDenied: + raise except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None - # to allow fallback to other backends + # If SUPERVISOR_COMMS fails for any non-authz reason, return None + # to allow fallback to other backends. return None diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index a179ff08436b2..60adf3821ddaf 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -841,6 +841,23 @@ def handle_request(request: httpx.Request) -> httpx.Response: key="test_key", ) + @pytest.mark.parametrize("status_code", [401, 403]) + def test_variable_get_authz_returns_permission_denied(self, status_code): + """401/403 from the API server is reported as PERMISSION_DENIED, not raised. + + The ExecutionAPISecretsBackend uses this distinction to refuse fallback + to a less-restrictive backend on an explicit deny. + """ + + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=status_code, json={"detail": "Forbidden"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + resp = client.variables.get(key="denied_var") + assert isinstance(resp, ErrorResponse) + assert resp.error == ErrorType.PERMISSION_DENIED + assert resp.detail == {"key": "denied_var", "status_code": status_code} + def test_variable_set_success(self): # Simulate a successful response from the server when putting a variable def handle_request(request: httpx.Request) -> httpx.Response: @@ -1121,6 +1138,19 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert isinstance(result, ErrorResponse) assert result.error == ErrorType.CONNECTION_NOT_FOUND + @pytest.mark.parametrize("status_code", [401, 403]) + def test_connection_get_authz_returns_permission_denied(self, status_code): + """401/403 from the API server is reported as PERMISSION_DENIED, not raised.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=status_code, json={"detail": "Forbidden"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.connections.get(conn_id="denied_conn") + assert isinstance(result, ErrorResponse) + assert result.error == ErrorType.PERMISSION_DENIED + assert result.detail == {"conn_id": "denied_conn", "status_code": status_code} + class TestAssetEventOperations: @pytest.mark.parametrize( diff --git a/task-sdk/tests/task_sdk/execution_time/test_secrets.py b/task-sdk/tests/task_sdk/execution_time/test_secrets.py index bda87a6a64ae0..5c9cb7a22f539 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_secrets.py +++ b/task-sdk/tests/task_sdk/execution_time/test_secrets.py @@ -19,6 +19,9 @@ import pytest +from airflow.sdk.api.datamodels._generated import ConnectionResponse +from airflow.sdk.exceptions import AirflowSecretsBackendAccessDenied, ErrorType +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend @@ -27,10 +30,6 @@ class TestExecutionAPISecretsBackend: def test_get_connection_via_supervisor_comms(self, mock_supervisor_comms): """Test that connection is retrieved via SUPERVISOR_COMMS.""" - from airflow.sdk.api.datamodels._generated import ConnectionResponse - from airflow.sdk.execution_time.comms import ConnectionResult - - # Mock connection response conn_response = ConnectionResponse( conn_id="test_conn", conn_type="http", @@ -52,10 +51,6 @@ def test_get_connection_via_supervisor_comms(self, mock_supervisor_comms): def test_get_connection_not_found(self, mock_supervisor_comms): """Test that None is returned when connection not found.""" - from airflow.sdk.exceptions import ErrorType - from airflow.sdk.execution_time.comms import ErrorResponse - - # Mock error response error_response = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"message": "Not found"}) mock_supervisor_comms.send.return_value = error_response @@ -67,9 +62,6 @@ def test_get_connection_not_found(self, mock_supervisor_comms): def test_get_variable_via_supervisor_comms(self, mock_supervisor_comms): """Test that variable is retrieved via SUPERVISOR_COMMS.""" - from airflow.sdk.execution_time.comms import VariableResult - - # Mock variable response var_result = VariableResult(key="test_var", value="test_value") mock_supervisor_comms.send.return_value = var_result @@ -81,10 +73,6 @@ def test_get_variable_via_supervisor_comms(self, mock_supervisor_comms): def test_get_variable_not_found(self, mock_supervisor_comms): """Test that None is returned when variable not found.""" - from airflow.sdk.exceptions import ErrorType - from airflow.sdk.execution_time.comms import ErrorResponse - - # Mock error response error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": "Not found"}) mock_supervisor_comms.send.return_value = error_response @@ -120,6 +108,150 @@ def test_get_conn_value_not_implemented(self): with pytest.raises(NotImplementedError, match="Use get_connection instead"): backend.get_conn_value("test_conn") + def test_get_connection_raises_on_permission_denied(self, mock_supervisor_comms): + """An explicit deny from the Execution API must raise, not fall through. + + Returning None on a 401/403 would let the secrets-backend dispatcher + fall through to a less-restrictive backend (e.g. EnvironmentVariablesBackend). + """ + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"conn_id": "denied_conn", "status_code": 403}, + ) + backend = ExecutionAPISecretsBackend() + with pytest.raises(AirflowSecretsBackendAccessDenied, match="connection 'denied_conn'"): + backend.get_connection("denied_conn") + + def test_get_variable_raises_on_permission_denied(self, mock_supervisor_comms): + """An explicit deny from the Execution API must raise for variables too.""" + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"key": "denied_var", "status_code": 403}, + ) + backend = ExecutionAPISecretsBackend() + with pytest.raises(AirflowSecretsBackendAccessDenied, match="variable 'denied_var'"): + backend.get_variable("denied_var") + + @pytest.mark.asyncio + async def test_aget_connection_raises_on_permission_denied(self, mock_supervisor_comms): + """Async variant must also raise on PERMISSION_DENIED.""" + + async def asend(*_args, **_kwargs): + return ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"conn_id": "denied_conn", "status_code": 403}, + ) + + mock_supervisor_comms.asend = asend + backend = ExecutionAPISecretsBackend() + with pytest.raises(AirflowSecretsBackendAccessDenied, match="connection 'denied_conn'"): + await backend.aget_connection("denied_conn") + + @pytest.mark.asyncio + async def test_aget_variable_raises_on_permission_denied(self, mock_supervisor_comms): + """Async variant for variables must also raise on PERMISSION_DENIED.""" + + async def asend(*_args, **_kwargs): + return ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"key": "denied_var", "status_code": 403}, + ) + + mock_supervisor_comms.asend = asend + backend = ExecutionAPISecretsBackend() + with pytest.raises(AirflowSecretsBackendAccessDenied, match="variable 'denied_var'"): + await backend.aget_variable("denied_var") + + +class TestDispatcherRefusesFallbackOnDeny: + """End-to-end: the secrets-backend dispatcher must NOT fall through on an authoritative deny. + + A backend-level raise is not enough on its own — the outer ``except Exception:`` in + ``context._get_connection`` / ``_get_variable`` / ``_async_get_connection`` previously + swallowed ``PermissionError`` and silently called the next (less-restrictive) backend. + These tests pin the dispatcher behaviour by inserting a spy backend AFTER + ``ExecutionAPISecretsBackend`` and asserting it is never called once the first backend + raises ``AirflowSecretsBackendAccessDenied``. + """ + + def test_get_connection_does_not_fall_through_after_deny(self, mock_supervisor_comms, monkeypatch): + from unittest.mock import MagicMock + + from airflow.sdk.execution_time import context as ctx_module + + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"conn_id": "denied_conn", "status_code": 403}, + ) + + later_backend = MagicMock(name="LaterBackend") + later_backend.get_connection.return_value = MagicMock(name="leaked_conn") + + monkeypatch.setattr( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", + lambda: [ExecutionAPISecretsBackend(), later_backend], + ) + + with pytest.raises(AirflowSecretsBackendAccessDenied, match="connection 'denied_conn'"): + ctx_module._get_connection("denied_conn") + + later_backend.get_connection.assert_not_called() + + def test_get_variable_does_not_fall_through_after_deny(self, mock_supervisor_comms, monkeypatch): + from unittest.mock import MagicMock + + from airflow.sdk.execution_time import context as ctx_module + + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"key": "denied_var", "status_code": 403}, + ) + + later_backend = MagicMock(name="LaterBackend") + later_backend.get_variable.return_value = "leaked-value" + + monkeypatch.setattr( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", + lambda: [ExecutionAPISecretsBackend(), later_backend], + ) + + with pytest.raises(AirflowSecretsBackendAccessDenied, match="variable 'denied_var'"): + ctx_module._get_variable("denied_var", deserialize_json=False) + + later_backend.get_variable.assert_not_called() + + @pytest.mark.asyncio + async def test_async_get_connection_does_not_fall_through_after_deny( + self, mock_supervisor_comms, monkeypatch + ): + from unittest.mock import MagicMock + + from airflow.sdk.execution_time import context as ctx_module + + async def asend(*_args, **_kwargs): + return ErrorResponse( + error=ErrorType.PERMISSION_DENIED, + detail={"conn_id": "denied_conn", "status_code": 403}, + ) + + mock_supervisor_comms.asend = asend + + later_backend = MagicMock(name="LaterBackend") + # The dispatcher prefers aget_connection if present; mock both for safety. + later_backend.aget_connection = MagicMock(return_value=MagicMock(name="leaked_conn")) + later_backend.get_connection = MagicMock(return_value=MagicMock(name="leaked_conn")) + + monkeypatch.setattr( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", + lambda: [ExecutionAPISecretsBackend(), later_backend], + ) + + with pytest.raises(AirflowSecretsBackendAccessDenied, match="connection 'denied_conn'"): + await ctx_module._async_get_connection("denied_conn") + + later_backend.aget_connection.assert_not_called() + later_backend.get_connection.assert_not_called() + class TestContextDetection: """Test context detection in ensure_secrets_backend_loaded."""