From 0bcf88e12f285b3cc7b18e0a66cc87ba5eb02171 Mon Sep 17 00:00:00 2001 From: Jun Yeong Kim Date: Tue, 5 May 2026 18:37:46 +0900 Subject: [PATCH 1/2] Task SDK: Add Variable.keys() to list variable keys by prefix --- .../execution_api/datamodels/variable.py | 6 ++++ .../execution_api/routes/__init__.py | 1 + .../execution_api/routes/variables.py | 29 ++++++++++++++- .../execution_api/versions/__init__.py | 5 +++ .../execution_api/versions/v2026_04_28.py | 28 +++++++++++++++ .../src/airflow/dag_processing/processor.py | 8 +++++ .../src/airflow/jobs/triggerer_job_runner.py | 7 ++++ .../versions/head/test_variables.py | 28 +++++++++++++++ .../versions/v2026_04_28/__init__.py | 16 +++++++++ .../versions/v2026_04_28/test_variables.py | 35 +++++++++++++++++++ task-sdk/src/airflow/sdk/api/client.py | 9 +++++ .../airflow/sdk/api/datamodels/_generated.py | 11 ++++++ .../src/airflow/sdk/definitions/variable.py | 11 ++++++ .../src/airflow/sdk/execution_time/comms.py | 12 +++++++ .../src/airflow/sdk/execution_time/context.py | 11 ++++++ .../sdk/execution_time/request_handlers.py | 13 +++++++ .../airflow/sdk/execution_time/supervisor.py | 4 +++ .../task_sdk/definitions/test_variables.py | 32 ++++++++++++++++- .../execution_time/test_supervisor.py | 15 ++++++++ 19 files changed, 279 insertions(+), 2 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_28.py create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/__init__.py create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/test_variables.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py index fd49a5eae46d6..b9403a878d277 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -34,3 +34,9 @@ class VariablePostBody(StrictBaseModel): value: str | None = Field(alias="val") description: str | None = Field(default=None) + + +class VariableKeysResponse(StrictBaseModel): + """Variable keys schema for list responses.""" + + keys: list[str] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 06f07aee82389..864bb02b1ca8e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -51,6 +51,7 @@ authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] ) +authenticated_router.include_router(variables.keys_router, prefix="/variables", tags=["Variables"]) authenticated_router.include_router(variables.router, prefix="/variables", tags=["Variables"]) authenticated_router.include_router(xcoms.router, prefix="/xcoms", tags=["XComs"]) authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py index a0a7cb56045c1..b0cb7d5eaf02f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py @@ -20,9 +20,12 @@ import logging from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, Path, Request, status +from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status +from sqlalchemy import select +from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.variable import ( + VariableKeysResponse, VariablePostBody, VariableResponse, ) @@ -59,6 +62,8 @@ async def has_variable_access( dependencies=[Depends(has_variable_access)], ) +keys_router = APIRouter() + log = logging.getLogger(__name__) @@ -120,3 +125,25 @@ def delete_variable( ): """Delete an Airflow Variable.""" Variable.delete(key=variable_key, team_name=team_name) + + +@keys_router.get( + "/keys", + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) +def get_variable_keys( + session: SessionDep, + team_name: Annotated[str | None, Depends(get_team_name_dep)] = None, + prefix: Annotated[str | None, Query()] = None, +) -> VariableKeysResponse: + """Get Airflow Variable keys, optionally filtered by prefix.""" + stmt = select(Variable.key) + if prefix is not None: + stmt = stmt.where(Variable.key.startswith(prefix)) + if team_name is not None: + stmt = stmt.where(Variable.team_name == team_name) + + keys = session.scalars(stmt).all() + return VariableKeysResponse(keys=list(keys)) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index dfa27f53ebd91..31ce8b8ff8dcf 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -41,11 +41,16 @@ RemoveUpstreamMapIndexesField, ) from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddStateEndpoints, AddTeamNameField +from airflow.api_fastapi.execution_api.versions.v2026_04_28 import AddVariableKeysEndpoint from airflow.api_fastapi.execution_api.versions.v2026_06_16 import AddRetryPolicyFields bundle = VersionBundle( HeadVersion(), Version("2026-06-16", AddRetryPolicyFields), + Version( + "2026-04-28", + AddVariableKeysEndpoint, + ), Version( "2026-04-17", AddTeamNameField, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_28.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_28.py new file mode 100644 index 0000000000000..0bc300a499837 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_28.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, endpoint + + +class AddVariableKeysEndpoint(VersionChange): + """Add GET /variables/keys endpoint for listing variable keys with optional prefix filter.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = (endpoint("/variables/keys", ["GET"]).didnt_exist,) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index aa9f07411f87d..d7c0a9d2b59fc 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -50,6 +50,7 @@ GetTaskStates, GetTICount, GetVariable, + GetVariableKeys, GetXCom, GetXComCount, GetXComSequenceItem, @@ -61,6 +62,7 @@ PrevSuccessfulDagRunResult, PutVariable, TaskStatesResult, + VariableKeysResult, VariableResult, XComCountResponse, XComResult, @@ -128,6 +130,7 @@ class DagFileParsingResult(BaseModel): DagFileParsingResult | GetConnection | GetVariable + | GetVariableKeys | PutVariable | GetTaskStates | GetTICount @@ -147,6 +150,7 @@ class DagFileParsingResult(BaseModel): DagFileParseRequest | ConnectionResult | VariableResult + | VariableKeysResult | TaskStatesResult | PreviousDagRunResult | PreviousTIResult @@ -628,6 +632,10 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int dump_opts = {"exclude_unset": True} else: resp = var + elif isinstance(msg, GetVariableKeys): + from airflow.sdk.execution_time.request_handlers import handle_get_variable_keys + + resp, dump_opts = handle_get_variable_keys(self.client, msg) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, DeleteVariable): diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 9a22e61ad8a03..a7e72d17f7d1f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -71,6 +71,7 @@ GetTaskStates, GetTICount, GetVariable, + GetVariableKeys, GetXCom, MaskSecret, OKResponse, @@ -79,6 +80,7 @@ TaskStatesResult, TICount, UpdateHITLDetail, + VariableKeysResult, VariableResult, XComResult, _new_encoder, @@ -87,6 +89,7 @@ from airflow.sdk.execution_time.request_handlers import ( handle_get_connection, handle_get_variable, + handle_get_variable_keys, handle_mask_secret, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader @@ -299,6 +302,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | messages.TriggerStateSync | ConnectionResult | VariableResult + | VariableKeysResult | XComResult | DagRunStateResult | DRCount @@ -320,6 +324,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | GetConnection | DeleteVariable | GetVariable + | GetVariableKeys | PutVariable | DeleteXCom | GetXCom @@ -516,6 +521,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) + elif isinstance(msg, GetVariableKeys): + resp, dump_opts = handle_get_variable_keys(self.client, msg) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, DeleteXCom): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py index fe7611636358d..391530dc8fb5b 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py @@ -244,6 +244,34 @@ def test_post_variable_access_denied(self, client, caplog): assert any(msg.startswith("Checking write access for task instance") for msg in caplog.messages) +class TestGetVariableKeys: + @pytest.mark.parametrize( + ("prefix", "expected_keys"), + [ + pytest.param(None, {"prod_db_url", "prod_api_key", "dev_debug"}, id="no-prefix"), + pytest.param("prod_", {"prod_db_url", "prod_api_key"}, id="with-prefix"), + pytest.param("staging_", set(), id="no-match"), + ], + ) + def test_get_variable_keys(self, client, session, prefix, expected_keys): + Variable.set(key="prod_db_url", value="postgres://...", session=session) + Variable.set(key="prod_api_key", value="secret", session=session) + Variable.set(key="dev_debug", value="true", session=session) + session.commit() + + params = {"prefix": prefix} if prefix is not None else {} + response = client.get("/execution/variables/keys", params=params) + + assert response.status_code == 200 + assert set(response.json()["keys"]) == expected_keys + + def test_get_variable_keys_empty_db(self, client): + response = client.get("/execution/variables/keys") + + assert response.status_code == 200 + assert response.json() == {"keys": []} + + class TestDeleteVariable: @pytest.mark.parametrize( ("keys_to_create", "key_to_delete"), diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/__init__.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/test_variables.py new file mode 100644 index 0000000000000..19654d255a32a --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_28/test_variables.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Last released execution API before `GET /variables/keys` was added.""" + client.headers["Airflow-API-Version"] = "2026-04-17" + return client + + +def test_variable_keys_endpoint_not_available_in_previous_version(old_ver_client): + response = old_ver_client.get("/execution/variables/keys") + + assert response.status_code == 404 diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 54927794bf17e..bfa90d72e9b8c 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -69,6 +69,7 @@ TITerminalStatePayload, TriggerDAGRunPayload, ValidationError as RemoteValidationError, + VariableKeysResponse, VariablePostBody, VariableResponse, XComResponse, @@ -477,6 +478,14 @@ def delete( # decouple from the server response string return OKResponse(ok=True) + def keys(self, prefix: str | None = None) -> VariableKeysResponse: + """List variable keys from the API server, optionally filtered by key prefix.""" + params: dict[str, str] = {} + if prefix is not None: + params["prefix"] = prefix + resp = self.client.get("variables/keys", params=params) + return VariableKeysResponse.model_validate_json(resp.read()) + class XComOperations: __slots__ = ("client",) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..e9048b12b2d01 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -442,6 +442,17 @@ class ValidationError(BaseModel): ctx: Annotated[dict[str, Any] | None, Field(title="Context")] = None +class VariableKeysResponse(BaseModel): + """ + Variable keys schema for list responses. + """ + + model_config = ConfigDict( + extra="forbid", + ) + keys: Annotated[list[str], Field(title="Keys")] + + class VariablePostBody(BaseModel): """ Request body schema for creating variables. diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index 2e4c9aae3ca0f..527788269620d 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -67,6 +67,17 @@ def set(cls, key: str, value: Any, description: str | None = None, serialize_jso except AirflowRuntimeError as e: log.exception(e) + @classmethod + def keys(cls, prefix: str | None = None) -> list[str]: + """ + Return all Variable keys that start with the given prefix. + + :param prefix: Optional key prefix to filter by. If None, all Variable keys are returned. + """ + from airflow.sdk.execution_time.context import _get_variable_keys + + return _get_variable_keys(prefix=prefix) + @classmethod def delete(cls, key: str) -> None: from airflow.sdk.exceptions import AirflowRuntimeError diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 87c7881333ad4..0247ba0ce3e7b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -545,6 +545,11 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult") +class VariableKeysResult(BaseModel): + keys: list[str] + type: Literal["VariableKeysResult"] = "VariableKeysResult" + + class DagRunResult(DagRun): type: Literal["DagRunResult"] = "DagRunResult" @@ -728,6 +733,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: | TaskBreadcrumbsResult | TaskStatesResult | VariableResult + | VariableKeysResult | XComCountResponse | XComResult | XComSequenceIndexResult @@ -862,6 +868,11 @@ class GetVariable(BaseModel): type: Literal["GetVariable"] = "GetVariable" +class GetVariableKeys(BaseModel): + prefix: str | None = None + type: Literal["GetVariableKeys"] = "GetVariableKeys" + + class PutVariable(BaseModel): key: str value: str | None @@ -1061,6 +1072,7 @@ class GetDag(BaseModel): | GetTaskBreadcrumbs | GetTaskStates | GetVariable + | GetVariableKeys | GetXCom | GetXComCount | GetXComSequenceItem diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 66c1f3aa8b7eb..757e56f56db69 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -278,6 +278,17 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: ) +def _get_variable_keys(prefix: str | None = None) -> list[str]: + from airflow.sdk.execution_time.comms import GetVariableKeys, VariableKeysResult + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg = SUPERVISOR_COMMS.send(GetVariableKeys(prefix=prefix)) + if not isinstance(msg, VariableKeysResult): + return [] + + return msg.keys + + def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None: # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` # or `airflow.sdk.execution_time.variable` diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index eed3e840e395b..3b528c2e71a15 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -31,13 +31,16 @@ from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, + VariableKeysResponse, VariableResponse, ) from airflow.sdk.execution_time.comms import ( ConnectionResult, GetConnection, GetVariable, + GetVariableKeys, MaskSecret, + VariableKeysResult, VariableResult, ) from airflow.sdk.log import mask_secret @@ -70,6 +73,16 @@ def handle_get_variable(client: Client, msg: GetVariable) -> tuple[BaseModel | N return var, {} +def handle_get_variable_keys( + client: Client, msg: GetVariableKeys +) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch variable keys filtered by prefix.""" + result = client.variables.keys(prefix=msg.prefix) + if not isinstance(result, VariableKeysResponse): + return result, {} + return VariableKeysResult(keys=result.keys), {"exclude_unset": True} + + def handle_mask_secret(msg: MaskSecret) -> None: """Register a value with the secrets masker.""" mask_secret(msg.value, msg.name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 375c5a9e30b8e..356fe9824e5c8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -90,6 +90,7 @@ GetTaskStates, GetTICount, GetVariable, + GetVariableKeys, GetXCom, GetXComCount, GetXComSequenceItem, @@ -124,6 +125,7 @@ from airflow.sdk.execution_time.request_handlers import ( handle_get_connection, handle_get_variable, + handle_get_variable_keys, handle_mask_secret, ) @@ -1423,6 +1425,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) + elif isinstance(msg, GetVariableKeys): + resp, dump_opts = handle_get_variable_keys(self.client, msg) elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 3717f834735ff..06f5ba1ae5437 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -25,7 +25,7 @@ from airflow.sdk import Variable from airflow.sdk.configuration import initialize_secrets_backends -from airflow.sdk.execution_time.comms import PutVariable, VariableResult +from airflow.sdk.execution_time.comms import GetVariableKeys, PutVariable, VariableKeysResult, VariableResult from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars @@ -90,6 +90,36 @@ def test_var_set(self, key, value, description, serialize_json, mock_supervisor_ ) +class TestVariableKeys: + @pytest.mark.parametrize( + ("prefix", "keys"), + [ + pytest.param( + None, + ["prod_db", "prod_api", "dev_debug"], + id="all", + ), + pytest.param( + "prod_", + ["prod_db", "prod_api"], + id="with-prefix", + ), + pytest.param( + "nonexistent_", + [], + id="empty-result", + ), + ], + ) + def test_keys(self, prefix, keys, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = VariableKeysResult(keys=keys) + + results = Variable.keys(prefix=prefix) + + mock_supervisor_comms.send.assert_called_once_with(msg=GetVariableKeys(prefix=prefix)) + assert results == keys + + class TestVariableFromSecrets: def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path): """Tests getting a variable from secrets backend.""" diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 3695af1fff592..90fc42fca5c9e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -97,6 +97,7 @@ GetTaskStates, GetTICount, GetVariable, + GetVariableKeys, GetXCom, GetXComCount, GetXComSequenceItem, @@ -127,6 +128,7 @@ TriggerDagRun, UpdateHITLDetail, ValidateInletsAndOutlets, + VariableKeysResult, VariableResult, XComCountResponse, XComResult, @@ -1548,6 +1550,19 @@ class RequestTestCase: ), expected_body={"ok": True, "type": "OKResponse"}, ), + RequestTestCase( + message=GetVariableKeys(prefix="test_"), + test_id="get_variable_keys", + client_mock=ClientMock( + method_path="variables.keys", + kwargs={"prefix": "test_"}, + response=VariableKeysResult(keys=["test_key"]), + ), + expected_body={ + "keys": ["test_key"], + "type": "VariableKeysResult", + }, + ), RequestTestCase( message=DeferTask(next_method="execute_callback", classpath="my-classpath"), test_id="patch_task_instance_to_deferred", From 3002078a55bbc6727f6416caa2437dcf6ac4e067 Mon Sep 17 00:00:00 2001 From: Jun Yeong Kim Date: Wed, 6 May 2026 01:30:37 +0900 Subject: [PATCH 2/2] Task SDK: Make Variable.keys() return a lazy proxy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per dev list feedback, wrap the result in lazy_object_proxy.Proxy so the Execution API call only happens on first access (iteration, indexing, len, etc.) and is cached for subsequent accesses. Matches the pattern already used for template context values. Also clarifies in the docstring that only keys stored in the metadata database are returned — secrets backends are not consulted. --- .../src/airflow/sdk/definitions/variable.py | 12 +++++++++--- .../task_sdk/definitions/test_variables.py | 19 ++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index 527788269620d..eed381cd7b201 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -70,13 +70,19 @@ def set(cls, key: str, value: Any, description: str | None = None, serialize_jso @classmethod def keys(cls, prefix: str | None = None) -> list[str]: """ - Return all Variable keys that start with the given prefix. + Return Variable keys that start with the given prefix. - :param prefix: Optional key prefix to filter by. If None, all Variable keys are returned. + The keys are fetched lazily on first access (iteration, indexing, len, etc.) + and cached for subsequent access. Only keys stored in the metadata database + are returned — secrets backends are not consulted. + + :param prefix: Optional key prefix to filter by. If None, all keys are returned. """ + import lazy_object_proxy + from airflow.sdk.execution_time.context import _get_variable_keys - return _get_variable_keys(prefix=prefix) + return lazy_object_proxy.Proxy(lambda: _get_variable_keys(prefix=prefix)) @classmethod def delete(cls, key: str) -> None: diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 06f5ba1ae5437..beaefab83d917 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -116,8 +116,25 @@ def test_keys(self, prefix, keys, mock_supervisor_comms): results = Variable.keys(prefix=prefix) + # keys() is lazy — no API call until the proxy is accessed + mock_supervisor_comms.send.assert_not_called() + + materialized = list(results) + mock_supervisor_comms.send.assert_called_once_with(msg=GetVariableKeys(prefix=prefix)) - assert results == keys + assert materialized == keys + + def test_keys_cached_after_first_access(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = VariableKeysResult(keys=["a", "b"]) + + results = Variable.keys(prefix="x_") + + # Multiple accesses should only trigger the API call once + list(results) + list(results) + len(results) + + mock_supervisor_comms.send.assert_called_once_with(msg=GetVariableKeys(prefix="x_")) class TestVariableFromSecrets: