diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 57ddd14a6cb2c..0712b6746115f 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1889,6 +1889,17 @@ workers: sensitive: true example: ~ default: "" + state_backend: + description: | + Full class name of a custom worker-side state backend. When set, task and asset state values are + routed through this backend so large payloads or credentialed storage stay on worker + infrastructure. The Execution API still records a reference string in the database. + + Leave empty (default) to use the standard path through the Task SDK supervisor. + version_added: 3.3.0 + type: string + example: "mypackage.state.S3StateBackend" + default: "" min_heartbeat_interval: description: | The minimum interval (in seconds) at which the worker checks the task instance's diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index bfce8db332855..7aa9fcba8372d 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -47,9 +47,25 @@ class TaskScope: @dataclass(frozen=True) class AssetScope: - """Identifies the state namespace for an asset.""" + """ + Identifies the state namespace for an asset. + + Server-side backends receive ``asset_id``. Worker-side backends receive ``name`` or ``uri`` + since workers do not have access to the integer ``asset_id``. + + Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if an asset is + deactivated and a new one created with the same name, both share the same ``name`` value. + State for inactive assets is cleaned up by the orphan GC pass; until then, stale rows exist + in the DB but cannot be written to (the Execution API resolver filters to active assets only). 
+ """ + + asset_id: int | None = None + name: str | None = None + uri: str | None = None - asset_id: int + def __post_init__(self) -> None: + if self.asset_id is None and self.name is None and self.uri is None: + raise ValueError("AssetScope requires at least one of: asset_id, name, or uri") StateScope = TaskScope | AssetScope @@ -186,3 +202,54 @@ def cleanup(self) -> None: retention policy. The backend is responsible for reading any relevant config (e.g. ``[state_store] default_retention_days``) and deciding what to delete. """ + + def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: + """ + Serialize a task state value before it is sent to the execution API for db persistence. + + Called by ``TaskStateAccessor.set()`` on the worker. The return value is what gets + stored in the DB — typically a reference path (e.g. an S3 key) rather than the + actual value. Default: return ``value`` unchanged. + + The returned reference must be deterministic — given the same ``ti_id`` and ``key`` it + must always return the same string. Do not use timestamps or random UUIDs as part of + the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external + object will be orphaned. + """ + return value + + def deserialize_task_state_from_ref(self, stored: str) -> str: + """ + Resolve a stored task state string back to the actual value. + + Called by ``TaskStateAccessor.get()`` after the stored string is retrieved from + the execution API. Default: return ``stored`` unchanged. + """ + return stored + + def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: + """ + Serialize an asset state value before it is sent to the Execution API for db persistence. + + Called by ``AssetStateAccessor.set()`` on the worker. The return value is what gets + stored in the DB — typically a reference path rather than the actual value. + Default: return ``value`` unchanged. 
+ + ``asset_ref`` is either the asset name or URI, depending on how the accessor was + constructed. It may be a URI string if the task inlet was declared as ``AssetUriRef``. + + The returned reference must be deterministic — given the same ``asset_ref`` and ``key`` it + must always return the same string. Do not use timestamps or random UUIDs as part of + the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external + object will be orphaned. + """ + return value + + def deserialize_asset_state_from_ref(self, stored: str) -> str: + """ + Resolve a stored asset state string back to the actual value. + + Called by ``AssetStateAccessor.get()`` after the stored string is retrieved from + the Execution API. Default: return ``stored`` unchanged. + """ + return stored diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index 47bce18a69eab..1ea31194e2788 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -18,7 +18,22 @@ import pytest -from airflow_shared.state import BaseStateBackend, StateScope +from airflow_shared.state import AssetScope, BaseStateBackend, StateScope + + +class TestAssetScope: + def test_requires_at_least_one_identifier(self): + with pytest.raises(ValueError, match="at least one of"): + AssetScope() + + def test_asset_id_alone_is_valid(self): + AssetScope(asset_id=1) + + def test_name_alone_is_valid(self): + AssetScope(name="my_asset") + + def test_uri_alone_is_valid(self): + AssetScope(uri="s3://bucket/key") class TestBaseStateBackend: @@ -70,3 +85,69 @@ def test_abstract_methods_cover_full_interface(self): """BaseStateBackend enforces all 8 sync+async methods as abstract.""" expected = {"get", "set", "delete", "clear", "aget", "aset", "adelete", "aclear"} assert BaseStateBackend.__abstractmethods__ == expected + + def test_task_state_serialize_deserialize_round_trip(self, backend): + original = "app_1234" + serialized = 
backend.serialize_task_state_to_ref(value=original, key="job_id", ti_id="abc-123") + deserialized = backend.deserialize_task_state_from_ref(serialized) + assert deserialized == original + + def test_custom_backend_overrides_task_state_ser_deser(self): + class MyBackend(BaseStateBackend): + def get(self, scope, key): ... + def set(self, scope, key, value): ... + def delete(self, scope, key): ... + def clear(self, scope, *, all_map_indices=False): ... + async def aget(self, scope, key): ... + async def aset(self, scope, key, value): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... + + def serialize_task_state_to_ref(self, *, value, key, ti_id): + return f"s3://bucket/{ti_id}/{key}" + + def deserialize_task_state_from_ref(self, stored): + return f"fetched:{stored}" + + b = MyBackend() + assert b.serialize_task_state_to_ref(value="app_1234", key="job_id", ti_id="abc-123") == ( + "s3://bucket/abc-123/job_id" + ) + assert ( + b.deserialize_task_state_from_ref("s3://bucket/abc-123/job_id") + == "fetched:s3://bucket/abc-123/job_id" + ) + + def test_asset_state_serialize_deserialize_round_trip(self, backend): + original = "2026-05-01" + serialized = backend.serialize_asset_state_to_ref( + value="2026-05-01", key="watermark", asset_ref="my_asset" + ) + deserialized = backend.deserialize_asset_state_from_ref(serialized) + assert deserialized == original + + def test_custom_backend_overrides_asset_state_ser_deser(self): + class MyBackend(BaseStateBackend): + def get(self, scope, key): ... + def set(self, scope, key, value): ... + def delete(self, scope, key): ... + def clear(self, scope, *, all_map_indices=False): ... + async def aget(self, scope, key): ... + async def aset(self, scope, key, value): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... 
+ + def serialize_asset_state_to_ref(self, *, value, key, asset_ref): + return f"s3://bucket/assets/{asset_ref}/{key}" + + def deserialize_asset_state_from_ref(self, stored): + return f"resolved:{stored}" + + b = MyBackend() + assert b.serialize_asset_state_to_ref(value="2026-05-01", key="watermark", asset_ref="my_asset") == ( + "s3://bucket/assets/my_asset/watermark" + ) + assert ( + b.deserialize_asset_state_from_ref("s3://bucket/assets/my_asset/watermark") + == "resolved:s3://bucket/assets/my_asset/watermark" + ) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 89a17fb52c3c9..4fc9f3d586a52 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -147,6 +147,7 @@ path = "src/airflow/sdk/__init__.py" "../shared/listeners/src/airflow_shared/listeners" = "src/airflow/sdk/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager" "../shared/providers_discovery/src/airflow_shared/providers_discovery" = "src/airflow/sdk/_shared/providers_discovery" +"../shared/state/src/airflow_shared/state" = "src/airflow/sdk/_shared/state" "../shared/template_rendering/src/airflow_shared/template_rendering" = "src/airflow/sdk/_shared/template_rendering" [tool.hatch.build.targets.wheel] @@ -240,6 +241,7 @@ apache-airflow = {workspace = true} apache-airflow-devel-common = {workspace = true} apache-airflow-providers-common-sql = {workspace = true} apache-airflow-providers-standard = {workspace = true} +apache-airflow-shared-state = {workspace = true} # To use: # @@ -316,6 +318,7 @@ shared_distributions = [ "apache-airflow-shared-secrets-backend", "apache-airflow-shared-secrets-masker", "apache-airflow-shared-serialization", + "apache-airflow-shared-state", "apache-airflow-shared-timezones", "apache-airflow-shared-observability", "apache-airflow-shared-plugins-manager", diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state new file mode 120000 index 
0000000000000..752da6322068e --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/state @@ -0,0 +1 @@ +../../../../../shared/state/src/airflow_shared/state \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 24074824b8840..df08b743a1019 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -93,6 +93,7 @@ OKResponse, PreviousDagRunResult, PreviousTIResult, + RescheduleTask, SkipDownstreamTasks, TaskRescheduleStartDate, TICount, @@ -104,8 +105,6 @@ from datetime import datetime from typing import ParamSpec - from airflow.sdk.execution_time.comms import RescheduleTask - P = ParamSpec("P") T = TypeVar("T") diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py index 4e438a2cbf737..fb32f990c5880 100644 --- a/task-sdk/src/airflow/sdk/configuration.py +++ b/task-sdk/src/airflow/sdk/configuration.py @@ -32,6 +32,7 @@ configure_parser_from_configuration_description, expand_env_var, ) +from airflow.sdk._shared.module_loading import import_string from airflow.sdk.execution_time.secrets import _SERVER_DEFAULT_SECRETS_SEARCH_PATH log = logging.getLogger(__name__) @@ -236,8 +237,6 @@ def initialize_secrets_backends( Uses SDK's conf instead of Core's conf. 
""" - from airflow.sdk._shared.module_loading import import_string - backend_list = [] worker_mode = False # Determine worker mode - if default_backends is not the server default, it's worker mode diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index afc5a120c0c25..1d5b0694667a0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -51,6 +51,7 @@ from typing_extensions import Self from airflow.sdk import Variable + from airflow.sdk._shared.state import TaskScope from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context @@ -65,6 +66,7 @@ ReceiveMsgType, VariableResult, ) + from airflow.sdk.state import BaseStateBackend from airflow.sdk.types import OutletEventAccessorsProtocol @@ -440,11 +442,29 @@ def get(self, key, default: Any = NOTSET) -> Any: raise +@cache +def _get_worker_state_backend() -> BaseStateBackend | None: + """Return the configured worker-side state backend, instantiated once and cached.""" + class_name = conf.get("workers", "state_backend", fallback="") + if not class_name: + return None + from airflow.sdk._shared.module_loading import import_string + + try: + return import_string(class_name)() + except (ImportError, AttributeError) as e: + raise ValueError( + f"Could not load worker state backend {class_name!r}. " + f"Check the [workers] state_backend config value. Error: {e}" + ) from e + + class TaskStateAccessor: """Accessor for task state scoped to the current task instance. 
Available as ``context['task_state']`` at task execution time.""" - def __init__(self, ti_id: UUID) -> None: + def __init__(self, ti_id: UUID, scope: TaskScope) -> None: self._ti_id = ti_id + self._scope = scope def __eq__(self, other: object) -> bool: if not isinstance(other, TaskStateAccessor): @@ -470,7 +490,11 @@ def get(self, key: str) -> str | None: if isinstance(resp, ErrorResponse) and resp.error != ErrorType.TASK_STATE_NOT_FOUND: raise AirflowRuntimeError(resp) if isinstance(resp, TaskStateResult): - return resp.value + stored = resp.value + # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from + # custom backend using the reference + backend = _get_worker_state_backend() + return backend.deserialize_task_state_from_ref(stored) if backend else stored return None def set(self, key: str, value: str, *, retention: timedelta | None = None) -> None: @@ -495,14 +519,29 @@ def set(self, key: str, value: str, *, retention: timedelta | None = None) -> No else: days = conf.getint("state_store", "default_retention_days") expires_at = None if days <= 0 else now + timedelta(days=days) - SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=value, expires_at=expires_at)) + + # if custom backend is configured, store the value on the custom backend, and return the reference + # to the stored value to store in the DB + backend = _get_worker_state_backend() + stored = ( + backend.serialize_task_state_to_ref(value=value, key=key, ti_id=str(self._ti_id)) + if backend + else value + ) + + SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=stored, expires_at=expires_at)) def delete(self, key: str) -> None: """Delete a single key. 
No-op if the key does not exist.""" from airflow.sdk.execution_time.comms import DeleteTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # cleanup the DB ref first, if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) + backend = _get_worker_state_backend() + if backend is not None: + backend.delete(self._scope, key) def clear(self, all_map_indices: bool = False) -> None: """ @@ -515,7 +554,23 @@ def clear(self, all_map_indices: bool = False) -> None: from airflow.sdk.execution_time.comms import ClearTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # cleanup the DB ref first, if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) + backend = _get_worker_state_backend() + if backend is not None: + backend.clear(self._scope, all_map_indices=all_map_indices) + + def _clear_backend_only(self) -> None: + """ + Clear external storage via the worker backend without sending a comms message. + + Used by clear_on_success: the server already clears DB rows as part of SucceedTask, + so the comms round-trip is redundant. 
+ """ + backend = _get_worker_state_backend() + if backend is not None: + backend.clear(self._scope) class AssetStateAccessor: @@ -565,7 +620,11 @@ def get(self, key: str) -> str | None: if isinstance(resp, ErrorResponse) and resp.error != ErrorType.ASSET_STATE_NOT_FOUND: raise AirflowRuntimeError(resp) if isinstance(resp, AssetStateResult): - return resp.value + stored = resp.value + # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from + # custom backend using the reference + backend = _get_worker_state_backend() + return backend.deserialize_asset_state_from_ref(stored) if backend else stored return None def set(self, key: str, value: str) -> None: @@ -573,15 +632,26 @@ def set(self, key: str, value: str) -> None: from airflow.sdk.execution_time.comms import SetAssetStateByName, SetAssetStateByUri, ToSupervisor from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # if custom backend is configured, store the value on the custom backend, and return the reference + # to the stored value to store in the DB + backend = _get_worker_state_backend() + asset_ref = self._name or self._uri or "" + stored = ( + backend.serialize_asset_state_to_ref(value=value, key=key, asset_ref=asset_ref) + if backend + else value + ) + msg: ToSupervisor if self._name: - msg = SetAssetStateByName(name=self._name, key=key, value=value) + msg = SetAssetStateByName(name=self._name, key=key, value=stored) elif self._uri: - msg = SetAssetStateByUri(uri=self._uri, key=key, value=value) + msg = SetAssetStateByUri(uri=self._uri, key=key, value=stored) SUPERVISOR_COMMS.send(msg) def delete(self, key: str) -> None: """Delete a single key. 
No-op if the key does not exist.""" + from airflow.sdk._shared.state import AssetScope from airflow.sdk.execution_time.comms import ( DeleteAssetStateByName, DeleteAssetStateByUri, @@ -594,11 +664,21 @@ def delete(self, key: str) -> None: msg = DeleteAssetStateByName(name=self._name, key=key) elif self._uri: msg = DeleteAssetStateByUri(uri=self._uri, key=key) + # DB ref first: if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). SUPERVISOR_COMMS.send(msg) + backend = _get_worker_state_backend() + if backend is not None: + backend.delete(AssetScope(name=self._name, uri=self._uri), key) def clear(self) -> None: """Delete all state keys for this asset.""" - from airflow.sdk.execution_time.comms import ClearAssetStateByName, ClearAssetStateByUri, ToSupervisor + from airflow.sdk._shared.state import AssetScope + from airflow.sdk.execution_time.comms import ( + ClearAssetStateByName, + ClearAssetStateByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS msg: ToSupervisor @@ -607,6 +687,9 @@ def clear(self) -> None: elif self._uri: msg = ClearAssetStateByUri(uri=self._uri) SUPERVISOR_COMMS.send(msg) + backend = _get_worker_state_backend() + if backend is not None: + backend.clear(AssetScope(name=self._name, uri=self._uri)) class AssetStateAccessors: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9cb766c9b2d13..761ce3714f564 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -128,6 +128,7 @@ from airflow.sdk.execution_time.xcom import XCom from airflow.sdk.listener import get_listener_manager from airflow.sdk.observability.metrics import stats_utils +from airflow.sdk.state import TaskScope from airflow.sdk.timezone import coerce_datetime if TYPE_CHECKING: @@ -251,7 +252,15 @@ def get_template_context(self) -> 
Context: "value": VariableAccessor(deserialize_json=False), }, "conn": ConnectionAccessor(), - "task_state": TaskStateAccessor(ti_id=self.id), + "task_state": TaskStateAccessor( + ti_id=self.id, + scope=TaskScope( + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index if self.map_index is not None else -1, + ), + ), } if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, AssetAlias)) for i in self.task.inlets): self._cached_template_context["asset_state"] = AssetStateAccessors(self.task.inlets) @@ -1455,6 +1464,8 @@ def _handle_current_task_success( if conf.getboolean("state_store", "clear_on_success"): log.info("Task state will be cleared by the server because clear_on_success is enabled.") + context["task_state"]._clear_backend_only() + msg = SucceedTask( end_date=end_date, task_outlets=task_outlets, diff --git a/task-sdk/src/airflow/sdk/state.py b/task-sdk/src/airflow/sdk/state.py new file mode 100644 index 0000000000000..ac2a1126fe4bd --- /dev/null +++ b/task-sdk/src/airflow/sdk/state.py @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from airflow.sdk._shared.state import ( + AssetScope as AssetScope, + BaseStateBackend as BaseStateBackend, + TaskScope as TaskScope, +) diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py index 98391927f8a4c..a21424ea101f6 100644 --- a/task-sdk/tests/task_sdk/docs/test_public_api.py +++ b/task-sdk/tests/task_sdk/docs/test_public_api.py @@ -65,6 +65,7 @@ def test_airflow_sdk_no_unexpected_exports(): "providers_manager_runtime", "lineage", "types", + "state", } unexpected = actual - public - ignore assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}" diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index a5ff7be9ce865..285f1fd9763e9 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -25,7 +25,12 @@ import pytest from airflow.sdk import BaseOperator, get_current_context, timezone -from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse, DagRun +from airflow.sdk._shared.state import TaskScope +from airflow.sdk.api.datamodels._generated import ( + AssetEventResponse, + AssetResponse, + DagRun, +) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.asset import ( Asset, @@ -93,6 +98,7 @@ set_current_context, ) from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend +from airflow.sdk.state import BaseStateBackend from tests_common.test_utils.config import conf_vars @@ -1063,11 +1069,12 @@ def get_connection(self, conn_id): class TestTaskStateAccessor: TI_ID = UUID("01900000-0000-0000-0000-000000000001") + SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task") def test_get_returns_value(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = TaskStateResult(value="app_001") - result = 
TaskStateAccessor(ti_id=self.TI_ID).get("job_id") + result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id") assert result == "app_001" mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID, key="job_id")) @@ -1077,7 +1084,7 @@ def test_get_returns_none_on_404(self, mock_supervisor_comms): error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"} ) - result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key") + result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("missing_key") assert result is None @@ -1087,7 +1094,7 @@ def test_get_raises_on_error(self, mock_supervisor_comms): ) with pytest.raises(AirflowRuntimeError): - TaskStateAccessor(ti_id=self.TI_ID).get("some_key") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("some_key") def test_set_operation_with_global_retention(self, mock_supervisor_comms, time_machine): """set() with no retention uses global default_retention_days config.""" @@ -1097,7 +1104,7 @@ def test_set_operation_with_global_retention(self, mock_supervisor_comms, time_m time_machine.move_to(now, tick=False) with conf_vars({("state_store", "default_retention_days"): "30"}): - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") mock_supervisor_comms.send.assert_called_once_with( SetTaskState( @@ -1114,7 +1121,9 @@ def test_set_with_retention_computes_expires_at(self, mock_supervisor_comms, tim now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc) time_machine.move_to(now, tick=False) - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", retention=timedelta(days=7)) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set( + "job_id", "app_001", retention=timedelta(days=7) + ) mock_supervisor_comms.send.assert_called_once_with( SetTaskState( @@ -1130,7 +1139,7 @@ def test_set_with_never_expire_sends_null_expires_at(self, mock_supervisor_comms 
mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", retention=NEVER_EXPIRE) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001", retention=NEVER_EXPIRE) mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", expires_at=None) @@ -1141,7 +1150,7 @@ def test_set_global_default_zero_sends_null_expires_at(self, mock_supervisor_com mock_supervisor_comms.send.return_value = OKResponse(ok=True) with conf_vars({("state_store", "default_retention_days"): "0"}): - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", expires_at=None) @@ -1150,14 +1159,14 @@ def test_set_global_default_zero_sends_null_expires_at(self, mock_supervisor_com def test_delete_operation(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).delete("job_id") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id") mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID, key="job_id")) def test_clear_default_sends_all_map_indices_false(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).clear() + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear() mock_supervisor_comms.send.assert_called_once_with( ClearTaskState(ti_id=self.TI_ID, all_map_indices=False) @@ -1166,7 +1175,7 @@ def test_clear_default_sends_all_map_indices_false(self, mock_supervisor_comms): def test_clear_all_map_indices_sends_flag_true(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - 
TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear(all_map_indices=True) mock_supervisor_comms.send.assert_called_once_with( ClearTaskState(ti_id=self.TI_ID, all_map_indices=True) @@ -1370,3 +1379,179 @@ def test_alias_inlet_no_resolved_assets_contributes_nothing(self, mock_superviso accessors = AssetStateAccessors([alias]) assert accessors._total == 0 + + +class InMemoryStateBackend(BaseStateBackend): + """Simple in-memory test backend.""" + + def __init__(self): + self._actual_key_value_store: dict[str, str] = {} # key -> actual value + self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI) + + def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: + ref = f"mem://{ti_id}/{key}" + self._actual_key_value_store[key] = value + self.reference[key] = ref + return ref + + def deserialize_task_state_from_ref(self, stored: str) -> str: + key = stored.rsplit("/", 1)[-1] + return self._actual_key_value_store.get(key, stored) + + def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: + ref = f"mem://{asset_ref}/{key}" + self._actual_key_value_store[key] = value + self.reference[key] = ref + return ref + + def deserialize_asset_state_from_ref(self, stored: str) -> str: + key = stored.rsplit("/", 1)[-1] + return self._actual_key_value_store.get(key, stored) + + def get(self, scope, key, *, session=None): ... + def set(self, scope, key, value, *, session=None): ... + + def delete(self, scope, key, *, session=None) -> None: + self._actual_key_value_store.pop(key, None) + self.reference.pop(key, None) + + def clear(self, scope, *, all_map_indices=False, session=None) -> None: + self._actual_key_value_store.clear() + self.reference.clear() + + async def aget(self, scope, key): ... + async def aset(self, scope, key, value): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... 
class TestTaskStateAccessorWithCustomBackend:
    """Exercise ``TaskStateAccessor`` while the worker state backend lookup is patched."""

    TI_ID = UUID("01900000-0000-0000-0000-000000000002")
    SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task")

    @pytest.fixture(autouse=True)
    def backend(self):
        """Yield an ``InMemoryStateBackend`` returned by the patched ``_get_worker_state_backend``."""
        in_memory = InMemoryStateBackend()
        lookup_patch = mock.patch(
            "airflow.sdk.execution_time.context._get_worker_state_backend",
            return_value=in_memory,
        )
        with lookup_patch:
            yield in_memory

    def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend, time_machine):
        """``set()`` keeps the payload in the backend and sends only the ``mem://`` reference."""
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
        reference = f"mem://{self.TI_ID}/job_id"

        # Freeze the clock so the expected ``expires_at`` can be stated exactly.
        now = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
        time_machine.move_to(now, tick=False)

        accessor = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE)
        accessor.set("job_id", "app_001")

        # The message handed to comms carries the reference rather than the payload.
        expected_message = SetTaskState(
            ti_id=self.TI_ID,
            key="job_id",
            value=reference,
            expires_at=now + timedelta(days=30),
        )
        mock_supervisor_comms.send.assert_called_once_with(expected_message)
        # The backend holds the payload and remembers the reference it produced.
        assert backend._actual_key_value_store["job_id"] == "app_001"
        assert backend.reference["job_id"] == reference

    def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
        """``get()`` turns the ``mem://`` reference returned by comms into the stored payload."""
        stored_ref = f"mem://{self.TI_ID}/job_id"
        backend._actual_key_value_store["job_id"] = "app_001"
        mock_supervisor_comms.send.return_value = TaskStateResult(value=stored_ref)

        accessor = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE)

        # The accessor returns the payload, not the reference string.
        assert accessor.get("job_id") == "app_001"

    def test_deletes_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
        """``delete()`` drops the payload from the backend and sends ``DeleteTaskState``."""
        backend._actual_key_value_store["job_id"] = "app_001"
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)

        accessor = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE)
        accessor.delete("job_id")

        # The payload is gone from backend storage ...
        assert "job_id" not in backend._actual_key_value_store
        # ... and a request to remove the DB reference was sent.
        removal = DeleteTaskState(ti_id=self.TI_ID, key="job_id")
        mock_supervisor_comms.send.assert_any_call(removal)

    def test_clears_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
        """``clear()`` drops every seeded payload from the backend and sends ``ClearTaskState``."""
        seeded = {"job_id": "app_001", "checkpoint": "step_3"}
        for key, value in seeded.items():
            backend._actual_key_value_store[key] = value
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)

        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear()

        for key in seeded:
            assert key not in backend._actual_key_value_store
        expected_message = ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)
        mock_supervisor_comms.send.assert_any_call(expected_message)


class TestAssetStateAccessorWithCustomBackend:
    """Exercise ``AssetStateAccessor`` while the worker state backend lookup is patched."""

    ASSET_NAME = "my_asset"

    @pytest.fixture(autouse=True)
    def backend(self):
        """Yield an ``InMemoryStateBackend`` returned by the patched ``_get_worker_state_backend``."""
        in_memory = InMemoryStateBackend()
        lookup_patch = mock.patch(
            "airflow.sdk.execution_time.context._get_worker_state_backend",
            return_value=in_memory,
        )
        with lookup_patch:
            yield in_memory

    def test_set_sends_reference_not_value(self, mock_supervisor_comms, backend):
        """``set()`` keeps the payload in the backend and sends only the ``mem://`` reference."""
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)

        accessor = AssetStateAccessor(name=self.ASSET_NAME)
        accessor.set("watermark", "2026-05-01")

        reference = f"mem://{self.ASSET_NAME}/watermark"
        # The message handed to comms carries the reference rather than the payload.
        expected_message = SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=reference)
        mock_supervisor_comms.send.assert_called_once_with(expected_message)
        # The backend holds the payload and remembers the reference it produced.
        assert backend._actual_key_value_store["watermark"] == "2026-05-01"
        assert backend.reference["watermark"] == reference

    def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
        """``get()`` turns the ``mem://`` reference returned by comms into the stored payload."""
        stored_ref = f"mem://{self.ASSET_NAME}/watermark"
        backend._actual_key_value_store["watermark"] = "2026-05-01"
        mock_supervisor_comms.send.return_value = AssetStateResult(value=stored_ref)

        accessor = AssetStateAccessor(name=self.ASSET_NAME)

        # The accessor returns the payload, not the reference string.
        assert accessor.get("watermark") == "2026-05-01"

    def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
        """``delete()`` drops the payload from the backend and sends ``DeleteAssetStateByName``."""
        backend._actual_key_value_store["watermark"] = "2026-05-01"
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)

        accessor = AssetStateAccessor(name=self.ASSET_NAME)
        accessor.delete("watermark")

        # The payload is gone from backend storage ...
        assert "watermark" not in backend._actual_key_value_store
        # ... and a request to remove the DB reference was sent.
        removal = DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark")
        mock_supervisor_comms.send.assert_any_call(removal)

    def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
        """``clear()`` drops every seeded payload from the backend and sends ``ClearAssetStateByName``."""
        seeded = {"watermark": "2026-05-01", "file_count": "42"}
        for key, value in seeded.items():
            backend._actual_key_value_store[key] = value
        mock_supervisor_comms.send.return_value = OKResponse(ok=True)

        AssetStateAccessor(name=self.ASSET_NAME).clear()

        for key in seeded:
            assert key not in backend._actual_key_value_store
        expected_message = ClearAssetStateByName(name=self.ASSET_NAME)
        mock_supervisor_comms.send.assert_any_call(expected_message)
def test_asset_state_set_sends_reference_via_custom_backend(
    self, create_runtime_ti, mock_supervisor_comms
):
    """With a worker backend patched in, asset state ``set()`` sends the backend's reference."""

    class WatcherOperator(BaseOperator):
        def execute(self, context):
            context["asset_state"].set("watermark", "2026-05-01")

    watched = Asset(name="my_asset", uri="s3://bucket/data")
    runtime_ti = create_runtime_ti(task=WatcherOperator(task_id="t", inlets=[watched]))
    mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect

    reference = "mem://my_asset/watermark"
    state_backend = mock.MagicMock()
    state_backend.serialize_asset_state_to_ref.return_value = reference

    backend_patch = mock.patch(
        "airflow.sdk.execution_time.context._get_worker_state_backend",
        return_value=state_backend,
    )
    with backend_patch:
        run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

    # The backend was asked exactly once to serialize the value written by the operator.
    state_backend.serialize_asset_state_to_ref.assert_called_once_with(
        value="2026-05-01", key="watermark", asset_ref="my_asset"
    )
    # The reference, not the raw value, is what went out over comms.
    expected_message = SetAssetStateByName(name="my_asset", key="watermark", value=reference)
    mock_supervisor_comms.send.assert_any_call(expected_message)


def test_task_state_set_sends_reference_via_custom_backend(
    self, create_runtime_ti, mock_supervisor_comms, time_machine
):
    """With a worker backend patched in, task state ``set()`` sends the backend's reference."""

    class MyOperator(BaseOperator):
        def execute(self, context):
            context["task_state"].set("job_id", "app_001")

    # Freeze the clock so the expected ``expires_at`` can be stated exactly.
    now = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
    time_machine.move_to(now, tick=False)
    runtime_ti = create_runtime_ti(task=MyOperator(task_id="t"))
    mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect

    state_backend = mock.MagicMock()
    reference = f"mem://{runtime_ti.id}/job_id"
    state_backend.serialize_task_state_to_ref.return_value = reference

    backend_patch = mock.patch(
        "airflow.sdk.execution_time.context._get_worker_state_backend",
        return_value=state_backend,
    )
    with backend_patch:
        run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

    # The backend receives the task instance id as a string.
    state_backend.serialize_task_state_to_ref.assert_called_once_with(
        value="app_001", key="job_id", ti_id=str(runtime_ti.id)
    )
    expected_message = SetTaskState(
        ti_id=runtime_ti.id,
        key="job_id",
        value=reference,
        expires_at=now + timedelta(days=30),
    )
    mock_supervisor_comms.send.assert_any_call(expected_message)


@conf_vars({("state_store", "clear_on_success"): "True"})
def test_clear_on_success_clears_backend_without_comms_roundtrip(
    self, create_runtime_ti, mock_supervisor_comms
):
    """With ``clear_on_success`` on, ``backend.clear()`` runs once and no ``ClearTaskState`` is sent."""
    state_backend = mock.MagicMock()

    class MyOperator(BaseOperator):
        def execute(self, context):
            pass

    runtime_ti = create_runtime_ti(task=MyOperator(task_id="t"))

    backend_patch = mock.patch(
        "airflow.sdk.execution_time.context._get_worker_state_backend",
        return_value=state_backend,
    )
    with backend_patch:
        run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

    state_backend.clear.assert_called_once()

    # Collect the type of every message sent, whether passed as ``msg=`` or positionally.
    sent_types = []
    for sent in mock_supervisor_comms.send.call_args_list:
        message = sent.kwargs.get("msg")
        if not message:
            message = sent.args[0] if sent.args else None
        sent_types.append(type(message))
    assert ClearTaskState not in sent_types