From 68d02628f382289804a4212d5ce49e8e93b5a011 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 14:43:31 +0530 Subject: [PATCH 01/20] execution API endpoints and datamodel --- .../execution_api/datamodels/asset_state.py | 13 ++++++ .../execution_api/datamodels/task_state.py | 13 ++++++ .../execution_api/routes/asset_state.py | 25 ++++++++++ .../execution_api/routes/task_state.py | 22 +++++++++ .../execution_api/versions/v2026_06_16.py | 3 ++ .../src/airflow/config_templates/config.yml | 11 +++++ .../versions/head/test_task_state.py | 28 +++++++++++ .../airflow/sdk/api/datamodels/_generated.py | 46 +++++++++++++++++++ 8 files changed, 161 insertions(+) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py index ec773201c7e2f..35852686b93fd 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py @@ -30,3 +30,16 @@ class AssetStatePutBody(StrictBaseModel): """Request body for setting an asset state value.""" value: str + + +class AssetStateItem(StrictBaseModel): + """Asset state key/value pair returned by the list endpoint.""" + + key: str + value: str + + +class AssetStateListResponse(StrictBaseModel): + """All asset state entries for an asset.""" + + items: list[AssetStateItem] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py index 3200f3177af35..93b5eadbbfcd8 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py @@ -30,3 +30,16 @@ class TaskStatePutBody(StrictBaseModel): """Request body for setting a task state value.""" value: str + + +class TaskStateItem(StrictBaseModel): + """Task state key/value pair returned by the list endpoint.""" + + key: str + value: str + + +class TaskStateListResponse(StrictBaseModel): + """All task state entries for a task instance.""" + + items: list[TaskStateItem] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py index 3ff321def8720..d510603937800 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py @@ -37,11 +37,14 @@ from airflow._shared.state import AssetScope from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.asset_state import ( + AssetStateItem, + AssetStateListResponse, AssetStatePutBody, AssetStateResponse, ) from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute from airflow.models.asset import AssetModel +from airflow.models.asset_state import AssetStateModel from airflow.state import get_state_backend # TODO(AIP-103): enforce that the requesting task is registered with the asset @@ -177,3 +180,25 @@ def clear_asset_state_by_uri( """Delete all state keys for an asset by asset URI.""" asset_id = _resolve_asset_id_by_uri(uri, session) get_state_backend().clear(AssetScope(asset_id=asset_id), session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.get("/by-name/all") +def list_asset_state_by_name( + name: Annotated[str, Query(min_length=1)], + session: SessionDep, +) -> AssetStateListResponse: + """List all key/value pairs for an asset identified by name.""" + asset_id = _resolve_asset_id_by_name(name, session) + rows = session.scalars(select(AssetStateModel).where(AssetStateModel.asset_id == asset_id)).all() + return AssetStateListResponse(items=[AssetStateItem(key=r.key, value=r.value) for r in rows]) + + +@router.get("/by-uri/all") +def list_asset_state_by_uri( + uri: Annotated[str, Query(min_length=1)], + session: SessionDep, +) -> AssetStateListResponse: + """List all key/value pairs for an asset identified by URI.""" + asset_id = _resolve_asset_id_by_uri(uri, session) + rows = session.scalars(select(AssetStateModel).where(AssetStateModel.asset_id == asset_id)).all() + return AssetStateListResponse(items=[AssetStateItem(key=r.key, value=r.value) for r in rows]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py index db24109969c76..780534d19f13e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py @@ -21,15 +21,19 @@ from cadwyn import VersionedAPIRouter from fastapi import HTTPException, Path, Query, Security, status +from sqlalchemy import select from sqlalchemy.orm import Session from airflow._shared.state import TaskScope from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.task_state import ( + TaskStateItem, + TaskStateListResponse, TaskStatePutBody, TaskStateResponse, ) from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth +from airflow.models.task_state import TaskStateModel from airflow.models.taskinstance import TaskInstance as TI from airflow.state import get_state_backend @@ -126,3 +130,21 @@ def clear_task_state( """ scope = _get_task_scope_for_ti(task_instance_id, session) get_state_backend().clear(scope, all_map_indices=all_map_indices, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.get("/{task_instance_id}") +def list_task_state( + task_instance_id: UUID, + session: SessionDep, +) -> TaskStateListResponse: + """List all key/value pairs for a task instance.""" + scope = _get_task_scope_for_ti(task_instance_id, session) + rows = session.scalars( + select(TaskStateModel).where( + TaskStateModel.dag_id == scope.dag_id, + TaskStateModel.run_id == scope.run_id, + TaskStateModel.task_id == scope.task_id, + TaskStateModel.map_index == scope.map_index, + ) + ).all() + return TaskStateListResponse(items=[TaskStateItem(key=r.key, value=r.value) for r in rows]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py index 779612bbde134..3f24e1b8f8622 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py @@ -65,14 +65,17 @@ class AddStateEndpoints(VersionChange): description = __doc__ instructions_to_migrate_to_previous_version = ( + endpoint("/state/ti/{task_instance_id}", ["GET"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["GET"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["PUT"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["DELETE"]).didnt_exist, endpoint("/state/ti/{task_instance_id}", ["DELETE"]).didnt_exist, + endpoint("/state/asset/by-name/all", ["GET"]).didnt_exist, endpoint("/state/asset/by-name/value", ["GET"]).didnt_exist, endpoint("/state/asset/by-name/value", ["PUT"]).didnt_exist, endpoint("/state/asset/by-name/value", ["DELETE"]).didnt_exist, endpoint("/state/asset/by-name/clear", ["DELETE"]).didnt_exist, + endpoint("/state/asset/by-uri/all", ["GET"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["GET"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["PUT"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["DELETE"]).didnt_exist, diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 8d5d6e5fd2611..015737df4cb95 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1886,6 +1886,17 @@ workers: sensitive: true example: ~ default: "" + state_backend: + description: | + Full class name of the state backend to use on workers for direct task state access, + bypassing the execution API. When set, ``TaskStateAccessor`` calls this backend directly + instead of routing through the supervisor comms path. + + Leave empty (default) to use the standard comms path through the supervisor. + version_added: 3.3.0 + type: string + example: "mypackage.state.S3StateBackend" + default: "" min_heartbeat_interval: description: | The minimum interval (in seconds) at which the worker checks the task instance's diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py index 8a66a0a23c739..f5abf1f9b8f13 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py @@ -51,6 +51,34 @@ def _api_url(ti_id, key: str | None = None) -> str: return f"{base}/{key}" if key else base +class TestListTaskState: + def test_list_returns_all_keys(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) + client.put(_api_url(ti.id, "checkpoint"), json={"value": "step_3"}) + + response = client.get(_api_url(ti.id)) + + assert response.status_code == 200 + items = {item["key"]: item["value"] for item in response.json()["items"]} + assert items == {"job_id": "spark_001", "checkpoint": "step_3"} + + def test_list_returns_empty_when_no_state( + self, client: TestClient, create_task_instance: CreateTaskInstance + ): + ti = create_task_instance() + + response = client.get(_api_url(ti.id)) + + assert response.status_code == 200 + assert response.json() == {"items": []} + + def test_list_missing_ti_returns_404(self, client: TestClient): + response = client.get(_api_url(uuid4())) + + assert response.status_code == 404 + + class TestGetTaskState: def test_get_returns_value(self, client: TestClient, create_task_instance: CreateTaskInstance): ti = create_task_instance() diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..df37e3771397e 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -63,6 +63,29 @@ class AssetProfile(BaseModel): type: Annotated[str, Field(title="Type")] +class AssetStateItem(BaseModel): + """ + Asset state key/value pair returned by the list endpoint. + """ + + model_config = ConfigDict( + extra="forbid", + ) + key: Annotated[str, Field(title="Key")] + value: Annotated[str, Field(title="Value")] + + +class AssetStateListResponse(BaseModel): + """ + All asset state entries for an asset. + """ + + model_config = ConfigDict( + extra="forbid", + ) + items: Annotated[list[AssetStateItem], Field(title="Items")] + + class AssetStatePutBody(BaseModel): """ Request body for setting an asset state value. @@ -367,6 +390,29 @@ class TaskInstanceState(str, Enum): DEFERRED = "deferred" +class TaskStateItem(BaseModel): + """ + Task state key/value pair returned by the list endpoint. + """ + + model_config = ConfigDict( + extra="forbid", + ) + key: Annotated[str, Field(title="Key")] + value: Annotated[str, Field(title="Value")] + + +class TaskStateListResponse(BaseModel): + """ + All task state entries for a task instance. + """ + + model_config = ConfigDict( + extra="forbid", + ) + items: Annotated[list[TaskStateItem], Field(title="Items")] + + class TaskStatePutBody(BaseModel): """ Request body for setting a task state value. From e2e16ea7e29dfea4e365b765f9921b118eeb5643 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 14:50:37 +0530 Subject: [PATCH 02/20] shared lib changes --- .../src/airflow_shared/state/__init__.py | 56 +++++++++++++++ shared/state/tests/state/test_state.py | 68 +++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 463d9f378f315..92fd99a977f36 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -122,3 +122,59 @@ async def aclear(self, scope: StateScope, *, all_map_indices: bool = False) -> N scope are cleared. Pass ``all_map_indices=True`` to wipe state across every mapped instance of the task. For ``AssetScope`` the flag has no effect. """ + + def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str: + """ + Serialize a task state value before it is sent to the execution API for db persistence. + + Called by ``TaskStateAccessor.set()`` on the worker. The return value is what gets + stored in the DB — typically a reference path (e.g. an S3 key) rather than the + actual value. Default: return ``value`` unchanged. + """ + return value + + def deserialize_task_state_value(self, stored: str) -> str: + """ + Resolve a stored task state string back to the actual value. + + Called by ``TaskStateAccessor.get()`` after the stored string is retrieved from + the execution API. Default: return ``stored`` unchanged. + """ + return stored + + def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str: + """ + Serialize an asset state value before it is sent to the Execution API for db persistence. + + Called by ``AssetStateAccessor.set()`` on the worker. The return value is what gets + stored in the DB — typically a reference path rather than the actual value. + Default: return ``value`` unchanged. + """ + return value + + def deserialize_asset_state_value(self, stored: str) -> str: + """ + Resolve a stored asset state string back to the actual value. + + Called by ``AssetStateAccessor.get()`` after the stored string is retrieved from + the Execution API. Default: return ``stored`` unchanged. + """ + return stored + + def purge_task_state(self, stored: str) -> None: + """ + Clean up the task state storage object on the custom backend identified by ``stored``. + + Called by ``TaskStateAccessor.delete()`` and ``TaskStateAccessor.clear()`` before + the DB reference is removed. ``stored`` is whatever ``serialize_task_state_value`` + returned. Default: no-op. + """ + + def purge_asset_state(self, stored: str) -> None: + """ + Clean up the asset state storage object on the custom backend identified by ``stored``. + + Called by ``AssetStateAccessor.delete()`` and ``AssetStateAccessor.clear()`` before + the DB reference is removed. ``stored`` is whatever ``serialize_asset_state_value`` + returned. Default: no-op. + """ diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index 47bce18a69eab..b3e11b2833d3b 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -70,3 +70,71 @@ def test_abstract_methods_cover_full_interface(self): """BaseStateBackend enforces all 8 sync+async methods as abstract.""" expected = {"get", "set", "delete", "clear", "aget", "aset", "adelete", "aclear"} assert BaseStateBackend.__abstractmethods__ == expected + + def test_task_state_serialize_deserialize_round_trip(self, backend): + original = "app_1234" + serialized = backend.serialize_task_state_value(value=original, key="job_id", ti_id="abc-123") + deserialized = backend.deserialize_task_state_value(serialized) + assert deserialized == original + + def test_custom_backend_overrides_task_state_ser_deser(self): + class MyBackend(BaseStateBackend): + def get(self, scope, key): ... + def set(self, scope, key, value): ... + def delete(self, scope, key): ... + def clear(self, scope, *, all_map_indices=False): ... + async def aget(self, scope, key): ... + async def aset(self, scope, key, value): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... + + def serialize_task_state_value(self, *, value, key, ti_id): + return f"s3://bucket/{ti_id}/{key}" + + def deserialize_task_state_value(self, stored): + return f"fetched:{stored}" + + b = MyBackend() + assert ( + b.serialize_task_state_value(value="app_1234", key="job_id", ti_id="abc-123") + == "s3://bucket/abc-123/job_id" + ) + assert ( + b.deserialize_task_state_value("s3://bucket/abc-123/job_id") + == "fetched:s3://bucket/abc-123/job_id" + ) + + def test_asset_state_serialize_deserialize_round_trip(self, backend): + original = "2026-05-01" + serialized = backend.serialize_asset_state_value( + value="2026-05-01", key="watermark", asset_name="my_asset" + ) + deserialized = backend.deserialize_asset_state_value(serialized) + assert deserialized == original + + def test_custom_backend_overrides_asset_state_ser_deser(self): + class MyBackend(BaseStateBackend): + def get(self, scope, key): ... + def set(self, scope, key, value): ... + def delete(self, scope, key): ... + def clear(self, scope, *, all_map_indices=False): ... + async def aget(self, scope, key): ... + async def aset(self, scope, key, value): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... + + def serialize_asset_state_value(self, *, value, key, asset_name): + return f"s3://bucket/assets/{asset_name}/{key}" + + def deserialize_asset_state_value(self, stored): + return f"resolved:{stored}" + + b = MyBackend() + assert ( + b.serialize_asset_state_value(value="2026-05-01", key="watermark", asset_name="my_asset") + == "s3://bucket/assets/my_asset/watermark" + ) + assert ( + b.deserialize_asset_state_value("s3://bucket/assets/my_asset/watermark") + == "resolved:s3://bucket/assets/my_asset/watermark" + ) From 36de8cb724667f25b55386f2f43cd0baf58cdbd7 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 15:17:27 +0530 Subject: [PATCH 03/20] task sdk: api client changes --- task-sdk/src/airflow/sdk/api/client.py | 34 ++++++++- task-sdk/src/airflow/sdk/exceptions.py | 2 + .../src/airflow/sdk/execution_time/comms.py | 50 ++++++++++++++ task-sdk/tests/task_sdk/api/test_client.py | 69 +++++++++++++++++++ 4 files changed, 153 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 493225b4699bd..cf5dc8fd0cf12 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -46,6 +46,7 @@ API_VERSION, AssetEventsResponse, AssetResponse, + AssetStateListResponse, AssetStatePutBody, AssetStateResponse, ConnectionResponse, @@ -60,6 +61,7 @@ PrevSuccessfulDagRunResponse, TaskBreadcrumbsResponse, TaskInstanceState, + TaskStateListResponse, TaskStatePutBody, TaskStateResponse, TaskStatesResponse, @@ -91,6 +93,7 @@ OKResponse, PreviousDagRunResult, PreviousTIResult, + RescheduleTask, SkipDownstreamTasks, TaskRescheduleStartDate, TICount, @@ -102,8 +105,6 @@ from datetime import datetime from typing import ParamSpec - from airflow.sdk.execution_time.comms import RescheduleTask - P = ParamSpec("P") T = TypeVar("T") @@ -695,6 +696,18 @@ def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse: self.client.delete(f"state/ti/{ti_id}/{key}") return OKResponse(ok=True) + def list(self, ti_id: uuid.UUID) -> TaskStateListResponse | ErrorResponse: + """Return all key/stored-value pairs for a task instance.""" + try: + resp = self.client.get(f"state/ti/{ti_id}") + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug("Task states cannot be retrieved for task instance", ti_id=ti_id) + return ErrorResponse(error=ErrorType.TASK_STATES_NOT_FOUND, detail={"ti_id": ti_id}) + raise + + return TaskStateListResponse.model_validate_json(resp.read()) + def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) -> OKResponse: """Clear all task state keys for a task instance via the API server.""" params = {"all_map_indices": "true"} if all_map_indices else {} @@ -749,6 +762,23 @@ def delete(self, key: str, *, name: str | None = None, uri: str | None = None) - self.client.delete(endpoint, params=params) return OKResponse(ok=True) + def list( + self, *, name: str | None = None, uri: str | None = None + ) -> AssetStateListResponse | ErrorResponse: + """Return all key/stored-value pairs for an asset via the API server.""" + endpoint, params = self._resolve_endpoint("all", name=name, uri=uri) + try: + resp = self.client.get(endpoint, params=params) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug("Asset state cannot be retrieved for asset", name=name, uri=uri) + return ErrorResponse( + error=ErrorType.ASSET_STATES_NOT_FOUND, detail={"name": name, "uri": uri} + ) + raise + + return AssetStateListResponse.model_validate_json(resp.read()) + def clear(self, *, name: str | None = None, uri: str | None = None) -> OKResponse: """Clear all state keys for an asset via the API server.""" endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri) diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index ed3bb3f14938e..5b1f822675627 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -81,7 +81,9 @@ class ErrorType(enum.Enum): XCOM_NOT_FOUND = "XCOM_NOT_FOUND" ASSET_NOT_FOUND = "ASSET_NOT_FOUND" TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND" + TASK_STATES_NOT_FOUND = "TASK_STATES_NOT_FOUND" ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND" + ASSET_STATES_NOT_FOUND = "ASSET_STATE_NOT_FOUND" DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS" GENERIC_ERROR = "GENERIC_ERROR" API_SERVER_ERROR = "API_SERVER_ERROR" diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 01528c728b1fe..b6b2c6b9bbf4f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -69,6 +69,8 @@ AssetEventResponse, AssetEventsResponse, AssetResponse, + AssetStateItem, + AssetStateListResponse, AssetStateResponse, BundleInfo, ConnectionResponse, @@ -82,6 +84,8 @@ TaskBreadcrumbsResponse, TaskInstance, TaskInstanceState, + TaskStateItem, + TaskStateListResponse, TaskStateResponse, TaskStatesResponse, TIDeferredStatePayload, @@ -762,6 +766,26 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: return cls(**dag_response.model_dump(exclude_defaults=True), type="DagResult") +class AllTaskStateResult(TaskStateListResponse): + """Response to GetAllTaskState: all key/value pairs for a task instance.""" + + type: Literal["AllTaskStateResult"] = "AllTaskStateResult" + + @classmethod + def from_api_response(cls, items: list[TaskStateItem]) -> AllTaskStateResult: + return cls(items=items, type="AllTaskStateResult") + + +class AllAssetStateResult(AssetStateListResponse): + """Response to GetAllAssetStateByName/Uri: all key/value pairs for an asset.""" + + type: Literal["AllAssetStateResult"] = "AllAssetStateResult" + + @classmethod + def from_api_response(cls, items: list[AssetStateItem]) -> AllAssetStateResult: + return cls(items=items, type="AllAssetStateResult") + + ToTask = Annotated[ AssetResult | AssetsByAliasResult @@ -778,6 +802,8 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: | SentFDs | StartupDetails | TaskRescheduleStartDate + | AllAssetStateResult + | AllTaskStateResult | TaskStateResult | TICount | TaskBreadcrumbsResult @@ -932,6 +958,13 @@ class ClearTaskState(BaseModel): type: Literal["ClearTaskState"] = "ClearTaskState" +class GetAllTaskState(BaseModel): + """Fetch all key/stored-value pairs for a task instance.""" + + ti_id: UUID + type: Literal["GetAllTaskState"] = "GetAllTaskState" + + class GetAssetStateByName(BaseModel): name: str key: str @@ -958,6 +991,20 @@ class SetAssetStateByUri(BaseModel): type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri" +class GetAllAssetStateByName(BaseModel): + """Fetch all key/value pairs for an asset identified by name.""" + + name: str + type: Literal["GetAllAssetStateByName"] = "GetAllAssetStateByName" + + +class GetAllAssetStateByUri(BaseModel): + """Fetch all key/value pairs for an asset identified by URI.""" + + uri: str + type: Literal["GetAllAssetStateByUri"] = "GetAllAssetStateByUri" + + class DeleteAssetStateByName(BaseModel): name: str key: str @@ -1199,6 +1246,9 @@ class GetDag(BaseModel): | GetPreviousDagRun | GetPreviousTI | GetTaskRescheduleStartDate + | GetAllAssetStateByName + | GetAllAssetStateByUri + | GetAllTaskState | GetTaskState | GetTICount | GetTaskBreadcrumbs diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index a179ff08436b2..eab824b40f268 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -36,6 +36,8 @@ from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, AssetResponse, + AssetStateItem, + AssetStateListResponse, AssetStateResponse, ConnectionResponse, DagResponse, @@ -44,6 +46,8 @@ HITLDetailRequest, HITLDetailResponse, HITLUser, + TaskStateItem, + TaskStateListResponse, TaskStateResponse, TerminalTIState, VariableResponse, @@ -1804,6 +1808,34 @@ def handle_request(request: httpx.Request) -> httpx.Response: result = client.task_state.clear(ti_id=self.TI_ID, all_map_indices=True) assert result == OKResponse(ok=True) + def test_list_returns_key_value_pairs(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert request.url.path == f"/state/ti/{self.TI_ID}" + return httpx.Response( + status_code=200, + json={ + "items": [{"key": "job_id", "value": "app_001"}, {"key": "checkpoint", "value": "step_3"}] + }, + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.list(ti_id=self.TI_ID) + assert isinstance(result, TaskStateListResponse) + assert result.items == [ + TaskStateItem(key="job_id", value="app_001"), + TaskStateItem(key="checkpoint", value="step_3"), + ] + + def test_list_returns_empty_when_no_state(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200, json={"items": []}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.list(ti_id=self.TI_ID) + assert isinstance(result, TaskStateListResponse) + assert result.items == [] + class TestAssetStateOperations: def test_get_by_name_success(self): @@ -1920,3 +1952,40 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) result = client.asset_state.clear(uri="s3://bucket/key") assert result == OKResponse(ok=True) + + def test_list_all_by_name_returns_items(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert request.url.path == "/state/asset/by-name/all" + assert request.url.params["name"] == "test_asset" + return httpx.Response( + status_code=200, + json={ + "items": [ + {"key": "watermark", "value": "2026-05-01"}, + {"key": "file_count", "value": "42"}, + ] + }, + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.list(name="test_asset") + assert isinstance(result, AssetStateListResponse) + assert result.items == [ + AssetStateItem(key="watermark", value="2026-05-01"), + AssetStateItem(key="file_count", value="42"), + ] + + def test_list_all_by_uri_returns_items(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert request.url.path == "/state/asset/by-uri/all" + assert request.url.params["uri"] == "s3://bucket/key" + return httpx.Response( + status_code=200, json={"items": [{"key": "watermark", "value": "2026-05-01"}]} + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.list(uri="s3://bucket/key") + assert isinstance(result, AssetStateListResponse) + assert result.items == [AssetStateItem(key="watermark", value="2026-05-01")] From bc94b1c0fa70699d4b2ff60ffad15c73efe000bd Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 15:33:02 +0530 Subject: [PATCH 04/20] task sdk: supervisor changes --- .../airflow/sdk/execution_time/supervisor.py | 26 +++++++++ .../execution_time/test_supervisor.py | 58 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 757a73b7e2edc..dae80636181c9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,6 +62,8 @@ from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( + AllAssetStateResult, + AllTaskStateResult, AssetEventsResult, AssetResult, AssetStateResult, @@ -80,6 +82,9 @@ DeleteVariable, DeleteXCom, ErrorResponse, + GetAllAssetStateByName, + GetAllAssetStateByUri, + GetAllTaskState, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, @@ -1655,6 +1660,27 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: if isinstance(task_state, ErrorResponse) else TaskStateResult.from_task_state_response(task_state) ) + elif isinstance(msg, GetAllTaskState): + result = self.client.task_state.list(msg.ti_id) + resp = ( + result + if isinstance(result, ErrorResponse) + else AllTaskStateResult.from_api_response(result.items) + ) + elif isinstance(msg, GetAllAssetStateByName): + result = self.client.asset_state.list(name=msg.name) + resp = ( + result + if isinstance(result, ErrorResponse) + else AllAssetStateResult.from_api_response(result.items) + ) + elif isinstance(msg, GetAllAssetStateByUri): + result = self.client.asset_state.list(uri=msg.uri) + resp = ( + result + if isinstance(result, ErrorResponse) + else AllAssetStateResult.from_api_response(result.items) + ) elif isinstance(msg, SetTaskState): self.client.task_state.set(msg.ti_id, msg.key, msg.value) resp = OKResponse(ok=True) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index a8f97f81ac266..a0a43728b9eb0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -60,16 +60,20 @@ AssetEventResponse, AssetProfile, AssetResponse, + AssetStateItem, DagRun, DagRunState, DagRunType, PreviousTIResponse, TaskInstance, TaskInstanceState, + TaskStateItem, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time import supervisor, task_runner from airflow.sdk.execution_time.comms import ( + AllAssetStateResult, + AllTaskStateResult, AssetEventsResult, AssetResult, AssetsByAliasResult, @@ -91,6 +95,9 @@ DeleteXCom, DRCount, ErrorResponse, + GetAllAssetStateByName, + GetAllAssetStateByUri, + GetAllTaskState, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, @@ -2705,6 +2712,24 @@ class RequestTestCase: ), expected_body={"value": "spark_app_001", "type": "TaskStateResult"}, ), + RequestTestCase( + message=GetAllTaskState(ti_id=TI_ID), + test_id="get_all_task_state", + client_mock=ClientMock( + method_path="task_state.list", + args=(TI_ID,), + response=AllTaskStateResult.from_api_response( + [ + TaskStateItem(key="job_id", value="app_001"), + TaskStateItem(key="checkpoint", value="step_3"), + ] + ), + ), + expected_body={ + "items": [{"key": "job_id", "value": "app_001"}, {"key": "checkpoint", "value": "step_3"}], + "type": "AllTaskStateResult", + }, + ), RequestTestCase( message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"), test_id="set_task_state", @@ -2769,6 +2794,39 @@ class RequestTestCase: ), expected_body={"value": "2026-04-30T00:00:00Z", "type": "AssetStateResult"}, ), + RequestTestCase( + message=GetAllAssetStateByName(name="debug_watcher_asset"), + test_id="get_all_asset_state_by_name", + client_mock=ClientMock( + method_path="asset_state.list", + kwargs={"name": "debug_watcher_asset"}, + response=AllAssetStateResult.from_api_response( + [ + AssetStateItem(key="watermark", value="2026-05-01"), + AssetStateItem(key="file_count", value="42"), + ] + ), + ), + expected_body={ + "items": [{"key": "watermark", "value": "2026-05-01"}, {"key": "file_count", "value": "42"}], + "type": "AllAssetStateResult", + }, + ), + RequestTestCase( + message=GetAllAssetStateByUri(uri="s3://bucket/key"), + test_id="get_all_asset_state_by_uri", + client_mock=ClientMock( + method_path="asset_state.list", + kwargs={"uri": "s3://bucket/key"}, + response=AllAssetStateResult.from_api_response( + [AssetStateItem(key="watermark", value="2026-05-01")] + ), + ), + expected_body={ + "items": [{"key": "watermark", "value": "2026-05-01"}], + "type": "AllAssetStateResult", + }, + ), RequestTestCase( message=SetAssetStateByName( name="debug_watcher_asset", key="watermark", value="2026-04-30T00:00:00Z" From 8b19c09ce11a7b2d07024de4bfb75e823117b804 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 15:40:49 +0530 Subject: [PATCH 05/20] task sdk: conf changes --- task-sdk/src/airflow/sdk/configuration.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py index 4e438a2cbf737..d7b3e2c926aba 100644 --- a/task-sdk/src/airflow/sdk/configuration.py +++ b/task-sdk/src/airflow/sdk/configuration.py @@ -32,6 +32,7 @@ configure_parser_from_configuration_description, expand_env_var, ) +from airflow.sdk._shared.module_loading import import_string from airflow.sdk.execution_time.secrets import _SERVER_DEFAULT_SECRETS_SEARCH_PATH log = logging.getLogger(__name__) @@ -210,6 +211,21 @@ def remove_all_read_configurations(self): self.remove_section(section) +def get_state_backend(): + """ + Get the state backend if configured via ``[workers] state_backend``. + + Returns the instantiated backend, or ``None`` if not configured. + """ + # Lazy import to trigger __getattr__ and lazy initialization + from airflow.sdk.configuration import conf + + class_name = conf.get("workers", "state_backend", fallback="") + if not class_name: + return None + return import_string(class_name)() + + def get_custom_secret_backend(worker_mode: bool = False): """ Get Secret Backend if defined in airflow.cfg. @@ -236,8 +252,6 @@ def initialize_secrets_backends( Uses SDK's conf instead of Core's conf. """ - from airflow.sdk._shared.module_loading import import_string - backend_list = [] worker_mode = False # Determine worker mode - if default_backends is not the server default, it's worker mode From 0656cef5b6b0a075bcb99c0dc69b78d977463400 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 16:28:37 +0530 Subject: [PATCH 06/20] task sdk: context accessor changes --- task-sdk/pyproject.toml | 3 + task-sdk/src/airflow/sdk/_shared/state | 1 + .../src/airflow/sdk/execution_time/context.py | 115 +++++++++- task-sdk/src/airflow/sdk/state.py | 21 ++ .../task_sdk/execution_time/test_context.py | 217 +++++++++++++++++- 5 files changed, 348 insertions(+), 9 deletions(-) create mode 120000 task-sdk/src/airflow/sdk/_shared/state create mode 100644 task-sdk/src/airflow/sdk/state.py diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 6e4a0b1017d01..79cc1938d240b 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -147,6 +147,7 @@ path = "src/airflow/sdk/__init__.py" "../shared/listeners/src/airflow_shared/listeners" = "src/airflow/sdk/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager" "../shared/providers_discovery/src/airflow_shared/providers_discovery" = "src/airflow/sdk/_shared/providers_discovery" +"../shared/state/src/airflow_shared/state" = "src/airflow/sdk/_shared/state" "../shared/template_rendering/src/airflow_shared/template_rendering" = "src/airflow/sdk/_shared/template_rendering" [tool.hatch.build.targets.wheel] @@ -240,6 +241,7 @@ apache-airflow = {workspace = true} apache-airflow-devel-common = {workspace = true} apache-airflow-providers-common-sql = {workspace = true} apache-airflow-providers-standard = {workspace = true} +apache-airflow-shared-state = {workspace = true} # To use: # @@ -316,6 +318,7 @@ shared_distributions = [ "apache-airflow-shared-secrets-backend", "apache-airflow-shared-secrets-masker", "apache-airflow-shared-serialization", + "apache-airflow-shared-state", "apache-airflow-shared-timezones", "apache-airflow-shared-observability", "apache-airflow-shared-plugins-manager", diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state new file mode 120000 index 0000000000000..cb2f9414b9c9e --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/state @@ -0,0 +1 @@ +/Users/amoghdesai/Documents/OSS/repos/airflow/shared/state/src/airflow_shared/state \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 7fdd8bddc7fc7..506a664eaf5ed 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -405,6 +405,19 @@ def get(self, key, default: Any = NOTSET) -> Any: raise +@cache +def _get_worker_state_backend(): + """ + Return the configured worker-side state backend, instantiated once and cached. + + # TODO: rebase / include https://github.com/apache/airflow/pull/66699 once merged + # to also forward ``retention_days`` through the comms layer. + """ + from airflow.sdk.configuration import get_state_backend + + return get_state_backend() + + class TaskStateAccessor: """Accessor for task state scoped to the current task instance. Available as ``context['task_state']`` at task execution time.""" @@ -435,7 +448,11 @@ def get(self, key: str) -> str | None: if isinstance(resp, ErrorResponse) and resp.error != ErrorType.TASK_STATE_NOT_FOUND: raise AirflowRuntimeError(resp) if isinstance(resp, TaskStateResult): - return resp.value + stored = resp.value + # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from + # custom backend using the reference + backend = _get_worker_state_backend() + return backend.deserialize_task_state_value(stored) if backend else stored return None def set(self, key: str, value: str) -> None: @@ -443,13 +460,33 @@ def set(self, key: str, value: str) -> None: from airflow.sdk.execution_time.comms import SetTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=value)) + # if custom backend is configured, store the value on the custom backend, and return the reference + # to the stored value to store in the DB + backend = _get_worker_state_backend() + stored = ( + backend.serialize_task_state_value(value=value, key=key, ti_id=str(self._ti_id)) + if backend + else value + ) + SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=stored)) def delete(self, key: str) -> None: """Delete a single key. No-op if the key does not exist.""" - from airflow.sdk.execution_time.comms import DeleteTaskState + from airflow.sdk.execution_time.comms import ( + DeleteTaskState, + GetTaskState, + TaskStateResult, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + backend = _get_worker_state_backend() + # if custom backend is configured, fetch the reference of the stored value from DB, + # and delete the actual value from custom backend using the reference, then delete the reference from DB + # as well + if backend is not None: + resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key)) + if isinstance(resp, TaskStateResult): + backend.purge_task_state(resp.value) SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) def clear(self, all_map_indices: bool = False) -> None: @@ -460,9 +497,18 @@ def clear(self, all_map_indices: bool = False) -> None: instance of the task (fleet-wide reset). Defaults to clearing only this task instance's own state. """ - from airflow.sdk.execution_time.comms import ClearTaskState + from airflow.sdk.execution_time.comms import AllTaskStateResult, ClearTaskState, GetAllTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # if custom backend is configured, fetch the references of all stored values for this task instance + # from DB, and delete the actual values from custom backend using the references, then delete + # the references from DB as well. + backend = _get_worker_state_backend() + if backend is not None: + resp = SUPERVISOR_COMMS.send(GetAllTaskState(ti_id=self._ti_id)) + if isinstance(resp, AllTaskStateResult): + for item in resp.items: + backend.purge_task_state(item.value) SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) @@ -513,7 +559,11 @@ def get(self, key: str) -> str | None: if isinstance(resp, ErrorResponse) and resp.error != ErrorType.ASSET_STATE_NOT_FOUND: raise AirflowRuntimeError(resp) if isinstance(resp, AssetStateResult): - return resp.value + stored = resp.value + # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from + # custom backend using the reference + backend = _get_worker_state_backend() + return backend.deserialize_asset_state_value(stored) if backend else stored return None def set(self, key: str, value: str) -> None: @@ -521,22 +571,49 @@ def set(self, key: str, value: str) -> None: from airflow.sdk.execution_time.comms import SetAssetStateByName, SetAssetStateByUri, ToSupervisor from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # if custom backend is configured, store the value on the custom backend, and return the reference + # to the stored value to store in the DB + backend = _get_worker_state_backend() + asset_name = self._name or self._uri or "" + stored = ( + backend.serialize_asset_state_value(value=value, key=key, asset_name=asset_name) + if backend + else value + ) + msg: ToSupervisor if self._name: - msg = SetAssetStateByName(name=self._name, key=key, value=value) + msg = SetAssetStateByName(name=self._name, key=key, value=stored) elif self._uri: - msg = SetAssetStateByUri(uri=self._uri, key=key, value=value) + msg = SetAssetStateByUri(uri=self._uri, key=key, value=stored) SUPERVISOR_COMMS.send(msg) def delete(self, key: str) -> None: """Delete a single key. No-op if the key does not exist.""" from airflow.sdk.execution_time.comms import ( + AssetStateResult, DeleteAssetStateByName, DeleteAssetStateByUri, + GetAssetStateByName, + GetAssetStateByUri, ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + backend = _get_worker_state_backend() + # if custom backend is configured, fetch the reference of the stored value from DB, + # and delete the actual value from custom backend using the reference, then delete the reference from DB + # as well + if backend is not None: + get_msg: ToSupervisor + if self._name: + get_msg = GetAssetStateByName(name=self._name, key=key) + elif self._uri: + get_msg = GetAssetStateByUri(uri=self._uri, key=key) + resp = SUPERVISOR_COMMS.send(get_msg) + if isinstance(resp, AssetStateResult): + backend.purge_asset_state(resp.value) + msg: ToSupervisor if self._name: msg = DeleteAssetStateByName(name=self._name, key=key) @@ -546,9 +623,31 @@ def delete(self, key: str) -> None: def clear(self) -> None: """Delete all state keys for this asset.""" - from airflow.sdk.execution_time.comms import ClearAssetStateByName, ClearAssetStateByUri, ToSupervisor + from airflow.sdk.execution_time.comms import ( + AllAssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + GetAllAssetStateByName, + GetAllAssetStateByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # if custom backend is configured, fetch the references of all stored values for this asset + # from DB, and delete the actual values from custom backend using the references, then delete + # the references from DB as well. + backend = _get_worker_state_backend() + if backend is not None: + list_msg: ToSupervisor + if self._name: + list_msg = GetAllAssetStateByName(name=self._name) + elif self._uri: + list_msg = GetAllAssetStateByUri(uri=self._uri) + resp = SUPERVISOR_COMMS.send(list_msg) + if isinstance(resp, AllAssetStateResult): + for item in resp.items: + backend.purge_asset_state(item.value) + msg: ToSupervisor if self._name: msg = ClearAssetStateByName(name=self._name) diff --git a/task-sdk/src/airflow/sdk/state.py b/task-sdk/src/airflow/sdk/state.py new file mode 100644 index 0000000000000..21f9acd54850d --- /dev/null +++ b/task-sdk/src/airflow/sdk/state.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk._shared.state import BaseStateBackend as BaseStateBackend diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index ff0e6025c637d..eb4dadd0e2c9e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -24,7 +24,13 @@ import pytest from airflow.sdk import BaseOperator, get_current_context, timezone -from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse, DagRun +from airflow.sdk.api.datamodels._generated import ( + AssetEventResponse, + AssetResponse, + AssetStateItem, + DagRun, + TaskStateItem, +) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.asset import ( Asset, @@ -39,6 +45,8 @@ from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ( + AllAssetStateResult, + AllTaskStateResult, AssetEventDagRunReferenceResult, AssetEventResult, AssetEventSourceTaskInstance, @@ -91,6 +99,7 @@ set_current_context, ) from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend +from airflow.sdk.state import BaseStateBackend def test_convert_connection_result_conn(): @@ -1317,3 +1326,209 @@ def test_alias_inlet_no_resolved_assets_contributes_nothing(self, mock_superviso accessors = AssetStateAccessors([alias]) assert accessors._total == 0 + + +class InMemoryStateBackend(BaseStateBackend): + """Concrete worker-side test backend — stores values in a dict, returns mem:// refs.""" + + def __init__(self): + self._store: dict[str, str] = {} + self.purged: list[str] = [] + + def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str: + ref = f"mem://{ti_id}/{key}" + self._store[ref] = value + return ref + + def deserialize_task_state_value(self, stored: str) -> str: + if stored.startswith("mem://"): + return self._store.get(stored, stored) + return stored + + def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str: + ref = f"mem://{asset_name}/{key}" + self._store[ref] = value + return ref + + def deserialize_asset_state_value(self, stored: str) -> str: + if stored.startswith("mem://"): + return self._store.get(stored, stored) + return stored + + def purge_task_state(self, stored: str) -> None: + self._store.pop(stored, None) + self.purged.append(stored) + + def purge_asset_state(self, stored: str) -> None: + self._store.pop(stored, None) + self.purged.append(stored) + + def get(self, scope, key, *, session=None): ... + def set(self, scope, key, value, *, retention_days=None, session=None): ... + def delete(self, scope, key, *, session=None): ... + def clear(self, scope, *, all_map_indices=False, session=None): ... + async def aget(self, scope, key): ... + async def aset(self, scope, key, value, *, retention_days=None): ... + async def adelete(self, scope, key): ... + async def aclear(self, scope, *, all_map_indices=False): ... + + +class TestTaskStateAccessorWithCustomBackend: + TI_ID = UUID("01900000-0000-0000-0000-000000000002") + + @pytest.fixture(autouse=True) + def backend(self): + b = InMemoryStateBackend() + with mock.patch( + "airflow.sdk.execution_time.context._get_worker_state_backend", + return_value=b, + ): + yield b + + def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend): + """set() stores actual value in backend and returns mem:// reference via comms.""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + expected_ref = f"mem://{self.TI_ID}/job_id" + + TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + # comms message has the mem:// reference, not the actual value + mock_supervisor_comms.send.assert_called_once_with( + SetTaskState(ti_id=self.TI_ID, key="job_id", value=expected_ref) + ) + # on backend, the value is stored under the mem:// reference + assert backend._store[expected_ref] == "app_001" + + def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend): + """get() fetches mem:// reference from DB, resolves it to actual value via backend.""" + ref = f"mem://{self.TI_ID}/job_id" + backend._store[ref] = "app_001" + mock_supervisor_comms.send.return_value = TaskStateResult(value=ref) + + result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id") + # actual value is resolved from mem:// reference via backend + assert result == "app_001" + + def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend): + """delete() purges from backend storage and removes the DB reference.""" + ref = f"mem://{self.TI_ID}/job_id" + backend._store[ref] = "app_001" + mock_supervisor_comms.send.side_effect = [ + TaskStateResult(value=ref), + OKResponse(ok=True), + ] + + TaskStateAccessor(ti_id=self.TI_ID).delete("job_id") + + # backend doesn't have the value anymore + assert ref not in backend._store + assert ref in backend.purged + + # request to delete reference in DB was made + mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=self.TI_ID, key="job_id")) + + def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend): + """clear() purges all backend objects for the TI and removes all DB references.""" + + ref_a = f"mem://{self.TI_ID}/job_id" + ref_b = f"mem://{self.TI_ID}/checkpoint" + backend._store[ref_a] = "app_001" + backend._store[ref_b] = "step_3" + + mock_supervisor_comms.send.side_effect = [ + AllTaskStateResult.from_api_response( + [ + TaskStateItem(key="job_id", value=ref_a), + TaskStateItem(key="checkpoint", value=ref_b), + ] + ), + OKResponse(ok=True), + ] + + TaskStateAccessor(ti_id=self.TI_ID).clear() + + assert ref_a not in backend._store + assert ref_b not in backend._store + assert backend.purged == [ref_a, ref_b] + mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)) + + +class TestAssetStateAccessorWithCustomBackend: + ASSET_NAME = "my_asset" + + @pytest.fixture(autouse=True) + def backend(self): + b = InMemoryStateBackend() + with mock.patch( + "airflow.sdk.execution_time.context._get_worker_state_backend", + return_value=b, + ): + yield b + + def test_set_sends_reference_not_value(self, mock_supervisor_comms, backend): + """set() stores actual value in backend and sends mem:// reference via comms.""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(name=self.ASSET_NAME).set("watermark", "2026-05-01") + + expected_ref = f"mem://{self.ASSET_NAME}/watermark" + # comms message has the mem:// reference, not the actual value + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=expected_ref) + ) + # on backend, the value is stored under the mem:// reference + assert backend._store[expected_ref] == "2026-05-01" + + def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend): + """get() fetches mem:// reference from DB, resolves it to actual value via backend.""" + ref = f"mem://{self.ASSET_NAME}/watermark" + backend._store[ref] = "2026-05-01" + mock_supervisor_comms.send.return_value = AssetStateResult(value=ref) + + result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark") + + # actual value is resolved from mem:// reference via backend + assert result == "2026-05-01" + + def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend): + """delete() purges from backend storage and removes the DB reference.""" + ref = f"mem://{self.ASSET_NAME}/watermark" + backend._store[ref] = "2026-05-01" + mock_supervisor_comms.send.side_effect = [ + AssetStateResult(value=ref), + OKResponse(ok=True), + ] + + AssetStateAccessor(name=self.ASSET_NAME).delete("watermark") + + # backend doesn't have the value anymore + assert ref not in backend._store + assert ref in backend.purged + + # request to delete reference in DB was made + mock_supervisor_comms.send.assert_any_call( + DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend): + """clear() purges all backend objects and removes all DB references.""" + ref_a = f"mem://{self.ASSET_NAME}/watermark" + ref_b = f"mem://{self.ASSET_NAME}/file_count" + backend._store[ref_a] = "2026-05-01" + backend._store[ref_b] = "42" + + mock_supervisor_comms.send.side_effect = [ + AllAssetStateResult.from_api_response( + [ + AssetStateItem(key="watermark", value=ref_a), + AssetStateItem(key="file_count", value=ref_b), + ] + ), + OKResponse(ok=True), + ] + + AssetStateAccessor(name=self.ASSET_NAME).clear() + + assert ref_a not in backend._store + assert ref_b not in backend._store + assert backend.purged == [ref_a, ref_b] + mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name=self.ASSET_NAME)) From 75156c44bed2b76bddbabf5e1969b8e93001399c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 16:35:34 +0530 Subject: [PATCH 07/20] task sdk: task runner changes --- .../airflow/sdk/execution_time/task_runner.py | 8 ++ .../execution_time/test_task_runner.py | 120 ++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 7c318fc499ed6..d443b49eb9d8b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1429,6 +1429,14 @@ def _handle_current_task_success( stats.incr("operator_successes", tags={**stats_tags, "operator_name": operator}) stats.incr("ti_successes", tags=stats_tags) + # TODO: uncomment below once https://github.com/apache/airflow/pull/66699 is merged + # if conf.getboolean("state_store", "clear_on_success"): + # log.info("Task state will be cleared by the server because clear_on_success is enabled.") + # + # if _get_worker_state_backend() is not None: + # # clear the task state keys for custom state backends configured on worker side + # context["task_state"].clear() + task_outlets = list(_build_asset_profiles(ti.task.outlets)) outlet_events = list(_serialize_outlet_events(context["outlet_events"])) msg = SucceedTask( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 723ca42d93aa6..0f135a6e21f75 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5104,3 +5104,123 @@ def execute(self, context): mock_supervisor_comms.send.assert_any_call( SetAssetStateByName(name="asset_b", key="watermark_b", value="2026-05-02") ) + + @conf_vars({("state_store", "clear_on_success"): "True"}) + def test_clear_on_success_calls_clear_when_worker_backend_configured( + self, create_runtime_ti, mock_supervisor_comms + ): + """When clear_on_success=True and a worker backend is configured, clear() is called on task success.""" + mock_backend = mock.MagicMock() + + class MyOperator(BaseOperator): + def execute(self, context): + pass + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + with mock.patch( + "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=runtime_ti.id, all_map_indices=False)) + + @conf_vars({("state_store", "clear_on_success"): "True"}) + def test_clear_on_success_skips_when_no_worker_backend(self, create_runtime_ti, mock_supervisor_comms): + """When clear_on_success=True but no worker backend configured, clear() is not called.""" + + class MyOperator(BaseOperator): + def execute(self, context): + pass + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + with mock.patch( + "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=None + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + calls = [str(c) for c in mock_supervisor_comms.send.call_args_list] + assert not any("ClearTaskState" in c for c in calls) + + @conf_vars({("state_store", "clear_on_success"): "False"}) + def test_clear_on_success_disabled_does_not_call_clear(self, create_runtime_ti, mock_supervisor_comms): + """When clear_on_success=False, clear() is not called even if a worker backend is configured.""" + mock_backend = mock.MagicMock() + + class MyOperator(BaseOperator): + def execute(self, context): + pass + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + with mock.patch( + "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + calls = [str(c) for c in mock_supervisor_comms.send.call_args_list] + assert not any("ClearTaskState" in c for c in calls) + + def test_asset_state_set_sends_reference_via_custom_backend( + self, create_runtime_ti, mock_supervisor_comms + ): + """When a worker backend is configured, asset state set() sends a reference, not the actual value.""" + watched = Asset(name="my_asset", uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"].set("watermark", "2026-05-01") + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + mock_backend = mock.MagicMock() + mock_backend.serialize_asset_state_value.return_value = "mem://my_asset/watermark" + + with mock.patch( + "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_backend.serialize_asset_state_value.assert_called_once_with( + value="2026-05-01", key="watermark", asset_name="my_asset" + ) + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="my_asset", key="watermark", value="mem://my_asset/watermark") + ) + + def test_task_state_set_sends_reference_via_custom_backend( + self, create_runtime_ti, mock_supervisor_comms + ): + """When a worker backend is configured, task state set() sends a reference, not the actual value.""" + + class MyOperator(BaseOperator): + def execute(self, context): + context["task_state"].set("job_id", "app_001") + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = [ + OKResponse(ok=True), # SetTaskState + SucceedTask(...), # finalize + ] + + mock_backend = mock.MagicMock() + mock_backend.serialize_task_state_value.return_value = f"mem://{runtime_ti.id}/job_id" + + with mock.patch( + "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_backend.serialize_task_state_value.assert_called_once_with( + value="app_001", key="job_id", ti_id=str(runtime_ti.id) + ) + mock_supervisor_comms.send.assert_any_call( + SetTaskState(ti_id=runtime_ti.id, key="job_id", value=f"mem://{runtime_ti.id}/job_id") + ) From 8d29dc45e2a31a9582465ad79dcf6624ec887f67 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 14 May 2026 13:54:48 +0530 Subject: [PATCH 08/20] removing wrongly committed file --- task-sdk/src/airflow/sdk/_shared/state | 1 - 1 file changed, 1 deletion(-) delete mode 120000 task-sdk/src/airflow/sdk/_shared/state diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state deleted file mode 120000 index cb2f9414b9c9e..0000000000000 --- a/task-sdk/src/airflow/sdk/_shared/state +++ /dev/null @@ -1 +0,0 @@ -/Users/amoghdesai/Documents/OSS/repos/airflow/shared/state/src/airflow_shared/state \ No newline at end of file From 867dc7f78cfe1269dec8c179847d5b3e5d147fde Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 14 May 2026 14:02:27 +0530 Subject: [PATCH 09/20] fixing CI failures --- airflow-core/src/airflow/config_templates/config.yml | 5 ++--- airflow-core/tests/unit/dag_processing/test_processor.py | 5 +++++ airflow-core/tests/unit/jobs/test_triggerer_job.py | 5 +++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 6ede6481b374a..94c4a16cae4a6 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1889,10 +1889,9 @@ workers: state_backend: description: | Full class name of the state backend to use on workers for direct task state access, - bypassing the execution API. When set, ``TaskStateAccessor`` calls this backend directly - instead of routing through the supervisor comms path. + bypassing the execution API. - Leave empty (default) to use the standard comms path through the supervisor. + Leave empty (default) to use the standard path through the task sdk supervisor. version_added: 3.3.0 type: string example: "mypackage.state.S3StateBackend" diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index d2ac085b0b736..f6be44e75224d 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1985,6 +1985,9 @@ def get_type_names(union_type): "DeleteAssetStateByUri", "ClearAssetStateByName", "ClearAssetStateByUri", + "GetAllAssetStateByName", + "GetAllAssetStateByUri", + "GetAllTaskState", } in_task_runner_but_not_in_dag_processing_process = { @@ -2007,6 +2010,8 @@ def get_type_names(union_type): # AIP-103 task/asset state results — worker-only responses to the above messages. "TaskStateResult", "AssetStateResult", + "AllAssetStateResult", + "AllTaskStateResult", } supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0501783b992d2..b5c186201ac57 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1809,6 +1809,9 @@ def get_type_names(union_type): "DeleteAssetStateByUri", "ClearAssetStateByName", "ClearAssetStateByUri", + "GetAllAssetStateByName", + "GetAllAssetStateByUri", + "GetAllTaskState", } in_task_but_not_in_trigger_runner = { @@ -1833,6 +1836,8 @@ def get_type_names(union_type): # AIP-103 task/asset state results — worker-only responses to the above messages. "TaskStateResult", "AssetStateResult", + "AllAssetStateResult", + "AllTaskStateResult", } supervisor_diff = ( From 432b432d8369552737cc1ae3d2dda8cbff40c84f Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 15 May 2026 18:06:26 +0530 Subject: [PATCH 10/20] ultimate change to simplify workers backend --- .../src/airflow_shared/state/__init__.py | 29 +--- task-sdk/src/airflow/sdk/_shared/state | 1 + .../src/airflow/sdk/execution_time/context.py | 64 ++----- .../airflow/sdk/execution_time/task_runner.py | 11 +- .../task_sdk/execution_time/test_context.py | 163 +++++++----------- 5 files changed, 98 insertions(+), 170 deletions(-) create mode 120000 task-sdk/src/airflow/sdk/_shared/state diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 534b6300cfd55..68b664e00eab1 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -45,9 +45,16 @@ class TaskScope: @dataclass(frozen=True) class AssetScope: - """Identifies the state namespace for an asset.""" + """ + Identifies the state namespace for an asset. + + We need to store ``name`` or ``uri`` since workers do not have access to the integer ``asset_id``. + (Access method is through AssetStateAccessor). + """ - asset_id: int + asset_id: int | None = None + name: str | None = None + uri: str | None = None StateScope = TaskScope | AssetScope @@ -204,21 +211,3 @@ def deserialize_asset_state_value(self, stored: str) -> str: the Execution API. Default: return ``stored`` unchanged. """ return stored - - def purge_task_state(self, stored: str) -> None: - """ - Clean up the task state storage object on the custom backend identified by ``stored``. - - Called by ``TaskStateAccessor.delete()`` and ``TaskStateAccessor.clear()`` before - the DB reference is removed. ``stored`` is whatever ``serialize_task_state_value`` - returned. Default: no-op. - """ - - def purge_asset_state(self, stored: str) -> None: - """ - Clean up the asset state storage object on the custom backend identified by ``stored``. - - Called by ``AssetStateAccessor.delete()`` and ``AssetStateAccessor.clear()`` before - the DB reference is removed. ``stored`` is whatever ``serialize_asset_state_value`` - returned. Default: no-op. - """ diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state new file mode 120000 index 0000000000000..cb2f9414b9c9e --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/state @@ -0,0 +1 @@ +/Users/amoghdesai/Documents/OSS/repos/airflow/shared/state/src/airflow_shared/state \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 55a4deb5fd4f8..8f5c02286b5b5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -50,6 +50,7 @@ from typing_extensions import Self from airflow.sdk import Variable + from airflow.sdk._shared.state import TaskScope from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context @@ -450,8 +451,9 @@ def _get_worker_state_backend(): class TaskStateAccessor: """Accessor for task state scoped to the current task instance. Available as ``context['task_state']`` at task execution time.""" - def __init__(self, ti_id: UUID) -> None: + def __init__(self, ti_id: UUID, scope: TaskScope) -> None: self._ti_id = ti_id + self._scope = scope def __eq__(self, other: object) -> bool: if not isinstance(other, TaskStateAccessor): @@ -501,21 +503,12 @@ def set(self, key: str, value: str) -> None: def delete(self, key: str) -> None: """Delete a single key. No-op if the key does not exist.""" - from airflow.sdk.execution_time.comms import ( - DeleteTaskState, - GetTaskState, - TaskStateResult, - ) + from airflow.sdk.execution_time.comms import DeleteTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS backend = _get_worker_state_backend() - # if custom backend is configured, fetch the reference of the stored value from DB, - # and delete the actual value from custom backend using the reference, then delete the reference from DB - # as well if backend is not None: - resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key)) - if isinstance(resp, TaskStateResult): - backend.purge_task_state(resp.value) + backend.delete(self._scope, key) SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) def clear(self, all_map_indices: bool = False) -> None: @@ -526,18 +519,12 @@ def clear(self, all_map_indices: bool = False) -> None: instance of the task (fleet-wide reset). Defaults to clearing only this task instance's own state. """ - from airflow.sdk.execution_time.comms import AllTaskStateResult, ClearTaskState, GetAllTaskState + from airflow.sdk.execution_time.comms import ClearTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # if custom backend is configured, fetch the references of all stored values for this task instance - # from DB, and delete the actual values from custom backend using the references, then delete - # the references from DB as well. backend = _get_worker_state_backend() if backend is not None: - resp = SUPERVISOR_COMMS.send(GetAllTaskState(ti_id=self._ti_id)) - if isinstance(resp, AllTaskStateResult): - for item in resp.items: - backend.purge_task_state(item.value) + backend.clear(self._scope, all_map_indices=all_map_indices) SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) @@ -619,29 +606,19 @@ def set(self, key: str, value: str) -> None: def delete(self, key: str) -> None: """Delete a single key. No-op if the key does not exist.""" + from airflow.sdk._shared.state import AssetScope from airflow.sdk.execution_time.comms import ( - AssetStateResult, DeleteAssetStateByName, DeleteAssetStateByUri, - GetAssetStateByName, - GetAssetStateByUri, ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS backend = _get_worker_state_backend() - # if custom backend is configured, fetch the reference of the stored value from DB, - # and delete the actual value from custom backend using the reference, then delete the reference from DB - # as well + # session=None signals worker-side: backend cleans up external storage only. + # DB reference is removed separately via comms below. if backend is not None: - get_msg: ToSupervisor - if self._name: - get_msg = GetAssetStateByName(name=self._name, key=key) - elif self._uri: - get_msg = GetAssetStateByUri(uri=self._uri, key=key) - resp = SUPERVISOR_COMMS.send(get_msg) - if isinstance(resp, AssetStateResult): - backend.purge_asset_state(resp.value) + backend.delete(AssetScope(name=self._name, uri=self._uri), key) msg: ToSupervisor if self._name: @@ -652,30 +629,19 @@ def delete(self, key: str) -> None: def clear(self) -> None: """Delete all state keys for this asset.""" + from airflow.sdk._shared.state import AssetScope from airflow.sdk.execution_time.comms import ( - AllAssetStateResult, ClearAssetStateByName, ClearAssetStateByUri, - GetAllAssetStateByName, - GetAllAssetStateByUri, ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # if custom backend is configured, fetch the references of all stored values for this asset - # from DB, and delete the actual values from custom backend using the references, then delete - # the references from DB as well. backend = _get_worker_state_backend() + # session=None signals worker-side: backend cleans up external storage only. + # DB references are cleared separately via comms below. if backend is not None: - list_msg: ToSupervisor - if self._name: - list_msg = GetAllAssetStateByName(name=self._name) - elif self._uri: - list_msg = GetAllAssetStateByUri(uri=self._uri) - resp = SUPERVISOR_COMMS.send(list_msg) - if isinstance(resp, AllAssetStateResult): - for item in resp.items: - backend.purge_asset_state(item.value) + backend.clear(AssetScope(name=self._name, uri=self._uri)) msg: ToSupervisor if self._name: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d443b49eb9d8b..50ad7045f7d3b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -43,6 +43,7 @@ from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk._shared.observability.metrics import stats +from airflow.sdk._shared.state import TaskScope from airflow.sdk._shared.template_rendering import truncate_rendered_value from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( @@ -251,7 +252,15 @@ def get_template_context(self) -> Context: "value": VariableAccessor(deserialize_json=False), }, "conn": ConnectionAccessor(), - "task_state": TaskStateAccessor(ti_id=self.id), + "task_state": TaskStateAccessor( + ti_id=self.id, + scope=TaskScope( + dag_id=self.dag_id, + run_id=self.run_id, + task_id=self.task_id, + map_index=self.map_index if self.map_index is not None else -1, + ), + ), } if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, AssetAlias)) for i in self.task.inlets): self._cached_template_context["asset_state"] = AssetStateAccessors(self.task.inlets) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index eb4dadd0e2c9e..acfc9418f7c8b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -24,12 +24,11 @@ import pytest from airflow.sdk import BaseOperator, get_current_context, timezone +from airflow.sdk._shared.state import TaskScope from airflow.sdk.api.datamodels._generated import ( AssetEventResponse, AssetResponse, - AssetStateItem, DagRun, - TaskStateItem, ) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.asset import ( @@ -45,8 +44,6 @@ from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ( - AllAssetStateResult, - AllTaskStateResult, AssetEventDagRunReferenceResult, AssetEventResult, AssetEventSourceTaskInstance, @@ -1068,11 +1065,12 @@ def get_connection(self, conn_id): class TestTaskStateAccessor: TI_ID = UUID("01900000-0000-0000-0000-000000000001") + SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task") def test_get_returns_value(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = TaskStateResult(value="app_001") - result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id") + result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id") assert result == "app_001" mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID, key="job_id")) @@ -1082,7 +1080,7 @@ def test_get_returns_none_on_404(self, mock_supervisor_comms): error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"} ) - result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key") + result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("missing_key") assert result is None @@ -1092,12 +1090,12 @@ def test_get_raises_on_error(self, mock_supervisor_comms): ) with pytest.raises(AirflowRuntimeError): - TaskStateAccessor(ti_id=self.TI_ID).get("some_key") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("some_key") def test_set_operation(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001") @@ -1106,14 +1104,14 @@ def test_set_operation(self, mock_supervisor_comms): def test_delete_operation(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).delete("job_id") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id") mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID, key="job_id")) def test_clear_default_sends_all_map_indices_false(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).clear() + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear() mock_supervisor_comms.send.assert_called_once_with( ClearTaskState(ti_id=self.TI_ID, all_map_indices=False) @@ -1122,7 +1120,7 @@ def test_clear_default_sends_all_map_indices_false(self, mock_supervisor_comms): def test_clear_all_map_indices_sends_flag_true(self, mock_supervisor_comms): mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear(all_map_indices=True) mock_supervisor_comms.send.assert_called_once_with( ClearTaskState(ti_id=self.TI_ID, all_map_indices=True) @@ -1329,52 +1327,52 @@ def test_alias_inlet_no_resolved_assets_contributes_nothing(self, mock_superviso class InMemoryStateBackend(BaseStateBackend): - """Concrete worker-side test backend — stores values in a dict, returns mem:// refs.""" + """Simple in-memory test backend.""" def __init__(self): - self._store: dict[str, str] = {} - self.purged: list[str] = [] + self._actual_key_value_store: dict[str, str] = {} # key -> actual value + self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI) def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str: ref = f"mem://{ti_id}/{key}" - self._store[ref] = value + self._actual_key_value_store[key] = value + self.reference[key] = ref return ref def deserialize_task_state_value(self, stored: str) -> str: - if stored.startswith("mem://"): - return self._store.get(stored, stored) - return stored + key = stored.rsplit("/", 1)[-1] + return self._actual_key_value_store.get(key, stored) def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str: ref = f"mem://{asset_name}/{key}" - self._store[ref] = value + self._actual_key_value_store[key] = value + self.reference[key] = ref return ref def deserialize_asset_state_value(self, stored: str) -> str: - if stored.startswith("mem://"): - return self._store.get(stored, stored) - return stored + key = stored.rsplit("/", 1)[-1] + return self._actual_key_value_store.get(key, stored) - def purge_task_state(self, stored: str) -> None: - self._store.pop(stored, None) - self.purged.append(stored) + def get(self, scope, key, *, session=None): ... + def set(self, scope, key, value, *, session=None): ... - def purge_asset_state(self, stored: str) -> None: - self._store.pop(stored, None) - self.purged.append(stored) + def delete(self, scope, key, *, session=None) -> None: + self._actual_key_value_store.pop(key, None) + self.reference.pop(key, None) + + def clear(self, scope, *, all_map_indices=False, session=None) -> None: + self._actual_key_value_store.clear() + self.reference.clear() - def get(self, scope, key, *, session=None): ... - def set(self, scope, key, value, *, retention_days=None, session=None): ... - def delete(self, scope, key, *, session=None): ... - def clear(self, scope, *, all_map_indices=False, session=None): ... async def aget(self, scope, key): ... - async def aset(self, scope, key, value, *, retention_days=None): ... + async def aset(self, scope, key, value): ... async def adelete(self, scope, key): ... async def aclear(self, scope, *, all_map_indices=False): ... class TestTaskStateAccessorWithCustomBackend: TI_ID = UUID("01900000-0000-0000-0000-000000000002") + SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task") @pytest.fixture(autouse=True) def backend(self): @@ -1386,69 +1384,51 @@ def backend(self): yield b def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend): - """set() stores actual value in backend and returns mem:// reference via comms.""" + """set() stores actual value in backend and sends mem:// reference via comms.""" mock_supervisor_comms.send.return_value = OKResponse(ok=True) expected_ref = f"mem://{self.TI_ID}/job_id" - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") # comms message has the mem:// reference, not the actual value mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value=expected_ref) ) - # on backend, the value is stored under the mem:// reference - assert backend._store[expected_ref] == "app_001" + # actual value is stored on the backend, reference is stored for DB + assert backend._actual_key_value_store["job_id"] == "app_001" + assert backend.reference["job_id"] == expected_ref def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend): """get() fetches mem:// reference from DB, resolves it to actual value via backend.""" ref = f"mem://{self.TI_ID}/job_id" - backend._store[ref] = "app_001" + backend._actual_key_value_store["job_id"] = "app_001" mock_supervisor_comms.send.return_value = TaskStateResult(value=ref) - result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id") + result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id") # actual value is resolved from mem:// reference via backend assert result == "app_001" - def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend): + def test_deletes_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend): """delete() purges from backend storage and removes the DB reference.""" - ref = f"mem://{self.TI_ID}/job_id" - backend._store[ref] = "app_001" - mock_supervisor_comms.send.side_effect = [ - TaskStateResult(value=ref), - OKResponse(ok=True), - ] - - TaskStateAccessor(ti_id=self.TI_ID).delete("job_id") + backend._actual_key_value_store["job_id"] = "app_001" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) - # backend doesn't have the value anymore - assert ref not in backend._store - assert ref in backend.purged + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id") + # backend does not have the value anymore + assert "job_id" not in backend._actual_key_value_store # request to delete reference in DB was made mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=self.TI_ID, key="job_id")) - def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend): + def test_clears_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend): """clear() purges all backend objects for the TI and removes all DB references.""" + backend._actual_key_value_store["job_id"] = "app_001" + backend._actual_key_value_store["checkpoint"] = "step_3" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) - ref_a = f"mem://{self.TI_ID}/job_id" - ref_b = f"mem://{self.TI_ID}/checkpoint" - backend._store[ref_a] = "app_001" - backend._store[ref_b] = "step_3" - - mock_supervisor_comms.send.side_effect = [ - AllTaskStateResult.from_api_response( - [ - TaskStateItem(key="job_id", value=ref_a), - TaskStateItem(key="checkpoint", value=ref_b), - ] - ), - OKResponse(ok=True), - ] - - TaskStateAccessor(ti_id=self.TI_ID).clear() + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear() - assert ref_a not in backend._store - assert ref_b not in backend._store - assert backend.purged == [ref_a, ref_b] + assert "job_id" not in backend._actual_key_value_store + assert "checkpoint" not in backend._actual_key_value_store mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)) @@ -1475,13 +1455,14 @@ def test_set_sends_reference_not_value(self, mock_supervisor_comms, backend): mock_supervisor_comms.send.assert_called_once_with( SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=expected_ref) ) - # on backend, the value is stored under the mem:// reference - assert backend._store[expected_ref] == "2026-05-01" + # actual value is stored on the backend, reference is stored for DB + assert backend._actual_key_value_store["watermark"] == "2026-05-01" + assert backend.reference["watermark"] == expected_ref def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend): """get() fetches mem:// reference from DB, resolves it to actual value via backend.""" ref = f"mem://{self.ASSET_NAME}/watermark" - backend._store[ref] = "2026-05-01" + backend._actual_key_value_store["watermark"] = "2026-05-01" mock_supervisor_comms.send.return_value = AssetStateResult(value=ref) result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark") @@ -1491,19 +1472,13 @@ def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, bac def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend): """delete() purges from backend storage and removes the DB reference.""" - ref = f"mem://{self.ASSET_NAME}/watermark" - backend._store[ref] = "2026-05-01" - mock_supervisor_comms.send.side_effect = [ - AssetStateResult(value=ref), - OKResponse(ok=True), - ] + backend._actual_key_value_store["watermark"] = "2026-05-01" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) AssetStateAccessor(name=self.ASSET_NAME).delete("watermark") # backend doesn't have the value anymore - assert ref not in backend._store - assert ref in backend.purged - + assert "watermark" not in backend._actual_key_value_store # request to delete reference in DB was made mock_supervisor_comms.send.assert_any_call( DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark") @@ -1511,24 +1486,12 @@ def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_com def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend): """clear() purges all backend objects and removes all DB references.""" - ref_a = f"mem://{self.ASSET_NAME}/watermark" - ref_b = f"mem://{self.ASSET_NAME}/file_count" - backend._store[ref_a] = "2026-05-01" - backend._store[ref_b] = "42" - - mock_supervisor_comms.send.side_effect = [ - AllAssetStateResult.from_api_response( - [ - AssetStateItem(key="watermark", value=ref_a), - AssetStateItem(key="file_count", value=ref_b), - ] - ), - OKResponse(ok=True), - ] + backend._actual_key_value_store["watermark"] = "2026-05-01" + backend._actual_key_value_store["file_count"] = "42" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) AssetStateAccessor(name=self.ASSET_NAME).clear() - assert ref_a not in backend._store - assert ref_b not in backend._store - assert backend.purged == [ref_a, ref_b] + assert "watermark" not in backend._actual_key_value_store + assert "file_count" not in backend._actual_key_value_store mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name=self.ASSET_NAME)) From 57437c6650369039b231ca54cd35d111cbc44d51 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 15 May 2026 18:43:03 +0530 Subject: [PATCH 11/20] ultimate undo of things --- .../execution_api/datamodels/asset_state.py | 13 ---- .../execution_api/datamodels/task_state.py | 13 ---- .../execution_api/routes/asset_state.py | 25 ------- .../execution_api/routes/task_state.py | 22 ------ .../execution_api/versions/v2026_06_16.py | 3 - .../versions/head/test_task_state.py | 28 -------- task-sdk/src/airflow/sdk/api/client.py | 31 -------- .../airflow/sdk/api/datamodels/_generated.py | 46 ------------ task-sdk/src/airflow/sdk/exceptions.py | 2 - .../src/airflow/sdk/execution_time/comms.py | 50 ------------- .../airflow/sdk/execution_time/supervisor.py | 26 ------- task-sdk/tests/task_sdk/api/test_client.py | 69 ------------------ .../execution_time/test_supervisor.py | 58 --------------- .../execution_time/test_task_runner.py | 72 ++----------------- 14 files changed, 4 insertions(+), 454 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py index 35852686b93fd..ec773201c7e2f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py @@ -30,16 +30,3 @@ class AssetStatePutBody(StrictBaseModel): """Request body for setting an asset state value.""" value: str - - -class AssetStateItem(StrictBaseModel): - """Asset state key/value pair returned by the list endpoint.""" - - key: str - value: str - - -class AssetStateListResponse(StrictBaseModel): - """All asset state entries for an asset.""" - - items: list[AssetStateItem] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py index 93b5eadbbfcd8..3200f3177af35 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py @@ -30,16 +30,3 @@ class TaskStatePutBody(StrictBaseModel): """Request body for setting a task state value.""" value: str - - -class TaskStateItem(StrictBaseModel): - """Task state key/value pair returned by the list endpoint.""" - - key: str - value: str - - -class TaskStateListResponse(StrictBaseModel): - """All task state entries for a task instance.""" - - items: list[TaskStateItem] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py index 46a6081988029..2351caa6dfaf0 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py @@ -37,14 +37,11 @@ from airflow._shared.state import AssetScope from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.asset_state import ( - AssetStateItem, - AssetStateListResponse, AssetStatePutBody, AssetStateResponse, ) from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute from airflow.models.asset import AssetModel -from airflow.models.asset_state import AssetStateModel from airflow.state import get_state_backend # TODO(AIP-103): enforce that the requesting task is registered with the asset @@ -180,25 +177,3 @@ def clear_asset_state_by_uri( """Delete all state keys for an asset by asset URI.""" asset_id = _resolve_asset_id_by_uri(uri, session) get_state_backend().clear(AssetScope(asset_id=asset_id), session=session) - - -@router.get("/by-name/all") -def list_asset_state_by_name( - name: Annotated[str, Query(min_length=1)], - session: SessionDep, -) -> AssetStateListResponse: - """List all key/value pairs for an asset identified by name.""" - asset_id = _resolve_asset_id_by_name(name, session) - rows = session.scalars(select(AssetStateModel).where(AssetStateModel.asset_id == asset_id)).all() - return AssetStateListResponse(items=[AssetStateItem(key=r.key, value=r.value) for r in rows]) - - -@router.get("/by-uri/all") -def list_asset_state_by_uri( - uri: Annotated[str, Query(min_length=1)], - session: SessionDep, -) -> AssetStateListResponse: - """List all key/value pairs for an asset identified by URI.""" - asset_id = _resolve_asset_id_by_uri(uri, session) - rows = session.scalars(select(AssetStateModel).where(AssetStateModel.asset_id == asset_id)).all() - return AssetStateListResponse(items=[AssetStateItem(key=r.key, value=r.value) for r in rows]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py index 182f950e11260..acdaa8c6a24ee 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py @@ -21,19 +21,15 @@ from cadwyn import VersionedAPIRouter from fastapi import HTTPException, Path, Query, Security, status -from sqlalchemy import select from sqlalchemy.orm import Session from airflow._shared.state import TaskScope from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.task_state import ( - TaskStateItem, - TaskStateListResponse, TaskStatePutBody, TaskStateResponse, ) from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth -from airflow.models.task_state import TaskStateModel from airflow.models.taskinstance import TaskInstance as TI from airflow.state import get_state_backend @@ -130,21 +126,3 @@ def clear_task_state( """ scope = _get_task_scope_for_ti(task_instance_id, session) get_state_backend().clear(scope, all_map_indices=all_map_indices, session=session) - - -@router.get("/{task_instance_id}") -def list_task_state( - task_instance_id: UUID, - session: SessionDep, -) -> TaskStateListResponse: - """List all key/value pairs for a task instance.""" - scope = _get_task_scope_for_ti(task_instance_id, session) - rows = session.scalars( - select(TaskStateModel).where( - TaskStateModel.dag_id == scope.dag_id, - TaskStateModel.run_id == scope.run_id, - TaskStateModel.task_id == scope.task_id, - TaskStateModel.map_index == scope.map_index, - ) - ).all() - return TaskStateListResponse(items=[TaskStateItem(key=r.key, value=r.value) for r in rows]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py index 3f24e1b8f8622..779612bbde134 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py @@ -65,17 +65,14 @@ class AddStateEndpoints(VersionChange): description = __doc__ instructions_to_migrate_to_previous_version = ( - endpoint("/state/ti/{task_instance_id}", ["GET"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["GET"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["PUT"]).didnt_exist, endpoint("/state/ti/{task_instance_id}/{key}", ["DELETE"]).didnt_exist, endpoint("/state/ti/{task_instance_id}", ["DELETE"]).didnt_exist, - endpoint("/state/asset/by-name/all", ["GET"]).didnt_exist, endpoint("/state/asset/by-name/value", ["GET"]).didnt_exist, endpoint("/state/asset/by-name/value", ["PUT"]).didnt_exist, endpoint("/state/asset/by-name/value", ["DELETE"]).didnt_exist, endpoint("/state/asset/by-name/clear", ["DELETE"]).didnt_exist, - endpoint("/state/asset/by-uri/all", ["GET"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["GET"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["PUT"]).didnt_exist, endpoint("/state/asset/by-uri/value", ["DELETE"]).didnt_exist, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py index f5abf1f9b8f13..8a66a0a23c739 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py @@ -51,34 +51,6 @@ def _api_url(ti_id, key: str | None = None) -> str: return f"{base}/{key}" if key else base -class TestListTaskState: - def test_list_returns_all_keys(self, client: TestClient, create_task_instance: CreateTaskInstance): - ti = create_task_instance() - client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) - client.put(_api_url(ti.id, "checkpoint"), json={"value": "step_3"}) - - response = client.get(_api_url(ti.id)) - - assert response.status_code == 200 - items = {item["key"]: item["value"] for item in response.json()["items"]} - assert items == {"job_id": "spark_001", "checkpoint": "step_3"} - - def test_list_returns_empty_when_no_state( - self, client: TestClient, create_task_instance: CreateTaskInstance - ): - ti = create_task_instance() - - response = client.get(_api_url(ti.id)) - - assert response.status_code == 200 - assert response.json() == {"items": []} - - def test_list_missing_ti_returns_404(self, client: TestClient): - response = client.get(_api_url(uuid4())) - - assert response.status_code == 404 - - class TestGetTaskState: def test_get_returns_value(self, client: TestClient, create_task_instance: CreateTaskInstance): ti = create_task_instance() diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 06ae59fb515fd..a1928357e2b21 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -46,7 +46,6 @@ API_VERSION, AssetEventsResponse, AssetResponse, - AssetStateListResponse, AssetStatePutBody, AssetStateResponse, ConnectionResponse, @@ -61,7 +60,6 @@ PrevSuccessfulDagRunResponse, TaskBreadcrumbsResponse, TaskInstanceState, - TaskStateListResponse, TaskStatePutBody, TaskStateResponse, TaskStatesResponse, @@ -705,18 +703,6 @@ def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse: self.client.delete(f"state/ti/{ti_id}/{key}") return OKResponse(ok=True) - def list(self, ti_id: uuid.UUID) -> TaskStateListResponse | ErrorResponse: - """Return all key/stored-value pairs for a task instance.""" - try: - resp = self.client.get(f"state/ti/{ti_id}") - except ServerResponseError as e: - if e.response.status_code == HTTPStatus.NOT_FOUND: - log.debug("Task states cannot be retrieved for task instance", ti_id=ti_id) - return ErrorResponse(error=ErrorType.TASK_STATES_NOT_FOUND, detail={"ti_id": ti_id}) - raise - - return TaskStateListResponse.model_validate_json(resp.read()) - def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) -> OKResponse: """Clear all task state keys for a task instance via the API server.""" params = {"all_map_indices": "true"} if all_map_indices else {} @@ -771,23 +757,6 @@ def delete(self, key: str, *, name: str | None = None, uri: str | None = None) - self.client.delete(endpoint, params=params) return OKResponse(ok=True) - def list( - self, *, name: str | None = None, uri: str | None = None - ) -> AssetStateListResponse | ErrorResponse: - """Return all key/stored-value pairs for an asset via the API server.""" - endpoint, params = self._resolve_endpoint("all", name=name, uri=uri) - try: - resp = self.client.get(endpoint, params=params) - except ServerResponseError as e: - if e.response.status_code == HTTPStatus.NOT_FOUND: - log.debug("Asset state cannot be retrieved for asset", name=name, uri=uri) - return ErrorResponse( - error=ErrorType.ASSET_STATES_NOT_FOUND, detail={"name": name, "uri": uri} - ) - raise - - return AssetStateListResponse.model_validate_json(resp.read()) - def clear(self, *, name: str | None = None, uri: str | None = None) -> OKResponse: """Clear all state keys for an asset via the API server.""" endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 049d0cdb46b89..9f1dadeef51ca 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -63,29 +63,6 @@ class AssetProfile(BaseModel): type: Annotated[str, Field(title="Type")] -class AssetStateItem(BaseModel): - """ - Asset state key/value pair returned by the list endpoint. - """ - - model_config = ConfigDict( - extra="forbid", - ) - key: Annotated[str, Field(title="Key")] - value: Annotated[str, Field(title="Value")] - - -class AssetStateListResponse(BaseModel): - """ - All asset state entries for an asset. - """ - - model_config = ConfigDict( - extra="forbid", - ) - items: Annotated[list[AssetStateItem], Field(title="Items")] - - class AssetStatePutBody(BaseModel): """ Request body for setting an asset state value. @@ -390,29 +367,6 @@ class TaskInstanceState(str, Enum): DEFERRED = "deferred" -class TaskStateItem(BaseModel): - """ - Task state key/value pair returned by the list endpoint. - """ - - model_config = ConfigDict( - extra="forbid", - ) - key: Annotated[str, Field(title="Key")] - value: Annotated[str, Field(title="Value")] - - -class TaskStateListResponse(BaseModel): - """ - All task state entries for a task instance. - """ - - model_config = ConfigDict( - extra="forbid", - ) - items: Annotated[list[TaskStateItem], Field(title="Items")] - - class TaskStatePutBody(BaseModel): """ Request body for setting a task state value. diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 99a8e5ce27477..b0ff82be293e7 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -81,9 +81,7 @@ class ErrorType(enum.Enum): XCOM_NOT_FOUND = "XCOM_NOT_FOUND" ASSET_NOT_FOUND = "ASSET_NOT_FOUND" TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND" - TASK_STATES_NOT_FOUND = "TASK_STATES_NOT_FOUND" ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND" - ASSET_STATES_NOT_FOUND = "ASSET_STATE_NOT_FOUND" DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS" GENERIC_ERROR = "GENERIC_ERROR" API_SERVER_ERROR = "API_SERVER_ERROR" diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index e622099283f96..a30872a6a54f0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -69,8 +69,6 @@ AssetEventResponse, AssetEventsResponse, AssetResponse, - AssetStateItem, - AssetStateListResponse, AssetStateResponse, BundleInfo, ConnectionResponse, @@ -84,8 +82,6 @@ TaskBreadcrumbsResponse, TaskInstance, TaskInstanceState, - TaskStateItem, - TaskStateListResponse, TaskStateResponse, TaskStatesResponse, TIDeferredStatePayload, @@ -772,26 +768,6 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: return cls(**dag_response.model_dump(exclude_defaults=True), type="DagResult") -class AllTaskStateResult(TaskStateListResponse): - """Response to GetAllTaskState: all key/value pairs for a task instance.""" - - type: Literal["AllTaskStateResult"] = "AllTaskStateResult" - - @classmethod - def from_api_response(cls, items: list[TaskStateItem]) -> AllTaskStateResult: - return cls(items=items, type="AllTaskStateResult") - - -class AllAssetStateResult(AssetStateListResponse): - """Response to GetAllAssetStateByName/Uri: all key/value pairs for an asset.""" - - type: Literal["AllAssetStateResult"] = "AllAssetStateResult" - - @classmethod - def from_api_response(cls, items: list[AssetStateItem]) -> AllAssetStateResult: - return cls(items=items, type="AllAssetStateResult") - - ToTask = Annotated[ AssetResult | AssetsByAliasResult @@ -808,8 +784,6 @@ def from_api_response(cls, items: list[AssetStateItem]) -> AllAssetStateResult: | SentFDs | StartupDetails | TaskRescheduleStartDate - | AllAssetStateResult - | AllTaskStateResult | TaskStateResult | TICount | TaskBreadcrumbsResult @@ -965,13 +939,6 @@ class ClearTaskState(BaseModel): type: Literal["ClearTaskState"] = "ClearTaskState" -class GetAllTaskState(BaseModel): - """Fetch all key/stored-value pairs for a task instance.""" - - ti_id: UUID - type: Literal["GetAllTaskState"] = "GetAllTaskState" - - class GetAssetStateByName(BaseModel): name: str key: str @@ -998,20 +965,6 @@ class SetAssetStateByUri(BaseModel): type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri" -class GetAllAssetStateByName(BaseModel): - """Fetch all key/value pairs for an asset identified by name.""" - - name: str - type: Literal["GetAllAssetStateByName"] = "GetAllAssetStateByName" - - -class GetAllAssetStateByUri(BaseModel): - """Fetch all key/value pairs for an asset identified by URI.""" - - uri: str - type: Literal["GetAllAssetStateByUri"] = "GetAllAssetStateByUri" - - class DeleteAssetStateByName(BaseModel): name: str key: str @@ -1260,9 +1213,6 @@ class GetDag(BaseModel): | GetPreviousDagRun | GetPreviousTI | GetTaskRescheduleStartDate - | GetAllAssetStateByName - | GetAllAssetStateByUri - | GetAllTaskState | GetTaskState | GetTICount | GetTaskBreadcrumbs diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index a2dc5fe83caf1..3e6236c578658 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,8 +62,6 @@ from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( - AllAssetStateResult, - AllTaskStateResult, AssetEventsResult, AssetResult, AssetStateResult, @@ -82,9 +80,6 @@ DeleteVariable, DeleteXCom, ErrorResponse, - GetAllAssetStateByName, - GetAllAssetStateByUri, - GetAllTaskState, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, @@ -1664,27 +1659,6 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: if isinstance(task_state, ErrorResponse) else TaskStateResult.from_task_state_response(task_state) ) - elif isinstance(msg, GetAllTaskState): - result = self.client.task_state.list(msg.ti_id) - resp = ( - result - if isinstance(result, ErrorResponse) - else AllTaskStateResult.from_api_response(result.items) - ) - elif isinstance(msg, GetAllAssetStateByName): - result = self.client.asset_state.list(name=msg.name) - resp = ( - result - if isinstance(result, ErrorResponse) - else AllAssetStateResult.from_api_response(result.items) - ) - elif isinstance(msg, GetAllAssetStateByUri): - result = self.client.asset_state.list(uri=msg.uri) - resp = ( - result - if isinstance(result, ErrorResponse) - else AllAssetStateResult.from_api_response(result.items) - ) elif isinstance(msg, SetTaskState): self.client.task_state.set(msg.ti_id, msg.key, msg.value) resp = OKResponse(ok=True) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index eab824b40f268..a179ff08436b2 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -36,8 +36,6 @@ from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, AssetResponse, - AssetStateItem, - AssetStateListResponse, AssetStateResponse, ConnectionResponse, DagResponse, @@ -46,8 +44,6 @@ HITLDetailRequest, HITLDetailResponse, HITLUser, - TaskStateItem, - TaskStateListResponse, TaskStateResponse, TerminalTIState, VariableResponse, @@ -1808,34 +1804,6 @@ def handle_request(request: httpx.Request) -> httpx.Response: result = client.task_state.clear(ti_id=self.TI_ID, all_map_indices=True) assert result == OKResponse(ok=True) - def test_list_returns_key_value_pairs(self): - def handle_request(request: httpx.Request) -> httpx.Response: - assert request.method == "GET" - assert request.url.path == f"/state/ti/{self.TI_ID}" - return httpx.Response( - status_code=200, - json={ - "items": [{"key": "job_id", "value": "app_001"}, {"key": "checkpoint", "value": "step_3"}] - }, - ) - - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.task_state.list(ti_id=self.TI_ID) - assert isinstance(result, TaskStateListResponse) - assert result.items == [ - TaskStateItem(key="job_id", value="app_001"), - TaskStateItem(key="checkpoint", value="step_3"), - ] - - def test_list_returns_empty_when_no_state(self): - def handle_request(request: httpx.Request) -> httpx.Response: - return httpx.Response(status_code=200, json={"items": []}) - - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.task_state.list(ti_id=self.TI_ID) - assert isinstance(result, TaskStateListResponse) - assert result.items == [] - class TestAssetStateOperations: def test_get_by_name_success(self): @@ -1952,40 +1920,3 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) result = client.asset_state.clear(uri="s3://bucket/key") assert result == OKResponse(ok=True) - - def test_list_all_by_name_returns_items(self): - def handle_request(request: httpx.Request) -> httpx.Response: - assert request.method == "GET" - assert request.url.path == "/state/asset/by-name/all" - assert request.url.params["name"] == "test_asset" - return httpx.Response( - status_code=200, - json={ - "items": [ - {"key": "watermark", "value": "2026-05-01"}, - {"key": "file_count", "value": "42"}, - ] - }, - ) - - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.asset_state.list(name="test_asset") - assert isinstance(result, AssetStateListResponse) - assert result.items == [ - AssetStateItem(key="watermark", value="2026-05-01"), - AssetStateItem(key="file_count", value="42"), - ] - - def test_list_all_by_uri_returns_items(self): - def handle_request(request: httpx.Request) -> httpx.Response: - assert request.method == "GET" - assert request.url.path == "/state/asset/by-uri/all" - assert request.url.params["uri"] == "s3://bucket/key" - return httpx.Response( - status_code=200, json={"items": [{"key": "watermark", "value": "2026-05-01"}]} - ) - - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.asset_state.list(uri="s3://bucket/key") - assert isinstance(result, AssetStateListResponse) - assert result.items == [AssetStateItem(key="watermark", value="2026-05-01")] diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 9c97a2f26b535..b54477b7769bb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -60,20 +60,16 @@ AssetEventResponse, AssetProfile, AssetResponse, - AssetStateItem, DagRun, DagRunState, DagRunType, PreviousTIResponse, TaskInstance, TaskInstanceState, - TaskStateItem, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time import supervisor, task_runner from airflow.sdk.execution_time.comms import ( - AllAssetStateResult, - AllTaskStateResult, AssetEventsResult, AssetResult, AssetsByAliasResult, @@ -95,9 +91,6 @@ DeleteXCom, DRCount, ErrorResponse, - GetAllAssetStateByName, - GetAllAssetStateByUri, - GetAllTaskState, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, @@ -2728,24 +2721,6 @@ class RequestTestCase: ), expected_body={"value": "spark_app_001", "type": "TaskStateResult"}, ), - RequestTestCase( - message=GetAllTaskState(ti_id=TI_ID), - test_id="get_all_task_state", - client_mock=ClientMock( - method_path="task_state.list", - args=(TI_ID,), - response=AllTaskStateResult.from_api_response( - [ - TaskStateItem(key="job_id", value="app_001"), - TaskStateItem(key="checkpoint", value="step_3"), - ] - ), - ), - expected_body={ - "items": [{"key": "job_id", "value": "app_001"}, {"key": "checkpoint", "value": "step_3"}], - "type": "AllTaskStateResult", - }, - ), RequestTestCase( message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"), test_id="set_task_state", @@ -2810,39 +2785,6 @@ class RequestTestCase: ), expected_body={"value": "2026-04-30T00:00:00Z", "type": "AssetStateResult"}, ), - RequestTestCase( - message=GetAllAssetStateByName(name="debug_watcher_asset"), - test_id="get_all_asset_state_by_name", - client_mock=ClientMock( - method_path="asset_state.list", - kwargs={"name": "debug_watcher_asset"}, - response=AllAssetStateResult.from_api_response( - [ - AssetStateItem(key="watermark", value="2026-05-01"), - AssetStateItem(key="file_count", value="42"), - ] - ), - ), - expected_body={ - "items": [{"key": "watermark", "value": "2026-05-01"}, {"key": "file_count", "value": "42"}], - "type": "AllAssetStateResult", - }, - ), - RequestTestCase( - message=GetAllAssetStateByUri(uri="s3://bucket/key"), - test_id="get_all_asset_state_by_uri", - client_mock=ClientMock( - method_path="asset_state.list", - kwargs={"uri": "s3://bucket/key"}, - response=AllAssetStateResult.from_api_response( - [AssetStateItem(key="watermark", value="2026-05-01")] - ), - ), - expected_body={ - "items": [{"key": "watermark", "value": "2026-05-01"}], - "type": "AllAssetStateResult", - }, - ), RequestTestCase( message=SetAssetStateByName( name="debug_watcher_asset", key="watermark", value="2026-04-30T00:00:00Z" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 0f135a6e21f75..c1b15b5e111a2 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5105,66 +5105,6 @@ def execute(self, context): SetAssetStateByName(name="asset_b", key="watermark_b", value="2026-05-02") ) - @conf_vars({("state_store", "clear_on_success"): "True"}) - def test_clear_on_success_calls_clear_when_worker_backend_configured( - self, create_runtime_ti, mock_supervisor_comms - ): - """When clear_on_success=True and a worker backend is configured, clear() is called on task success.""" - mock_backend = mock.MagicMock() - - class MyOperator(BaseOperator): - def execute(self, context): - pass - - task = MyOperator(task_id="t") - runtime_ti = create_runtime_ti(task=task) - - with mock.patch( - "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend - ): - run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - - mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=runtime_ti.id, all_map_indices=False)) - - @conf_vars({("state_store", "clear_on_success"): "True"}) - def test_clear_on_success_skips_when_no_worker_backend(self, create_runtime_ti, mock_supervisor_comms): - """When clear_on_success=True but no worker backend configured, clear() is not called.""" - - class MyOperator(BaseOperator): - def execute(self, context): - pass - - task = MyOperator(task_id="t") - runtime_ti = create_runtime_ti(task=task) - - with mock.patch( - "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=None - ): - run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - - calls = [str(c) for c in mock_supervisor_comms.send.call_args_list] - assert not any("ClearTaskState" in c for c in calls) - - @conf_vars({("state_store", "clear_on_success"): "False"}) - def test_clear_on_success_disabled_does_not_call_clear(self, create_runtime_ti, mock_supervisor_comms): - """When clear_on_success=False, clear() is not called even if a worker backend is configured.""" - mock_backend = mock.MagicMock() - - class MyOperator(BaseOperator): - def execute(self, context): - pass - - task = MyOperator(task_id="t") - runtime_ti = create_runtime_ti(task=task) - - with mock.patch( - "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend - ): - run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - - calls = [str(c) for c in mock_supervisor_comms.send.call_args_list] - assert not any("ClearTaskState" in c for c in calls) - def test_asset_state_set_sends_reference_via_custom_backend( self, create_runtime_ti, mock_supervisor_comms ): @@ -5205,13 +5145,11 @@ def execute(self, context): task = MyOperator(task_id="t") runtime_ti = create_runtime_ti(task=task) - mock_supervisor_comms.send.side_effect = [ - OKResponse(ok=True), # SetTaskState - SucceedTask(...), # finalize - ] + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect mock_backend = mock.MagicMock() - mock_backend.serialize_task_state_value.return_value = f"mem://{runtime_ti.id}/job_id" + ref = f"mem://{runtime_ti.id}/job_id" + mock_backend.serialize_task_state_value.return_value = ref with mock.patch( "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend @@ -5221,6 +5159,4 @@ def execute(self, context): mock_backend.serialize_task_state_value.assert_called_once_with( value="app_001", key="job_id", ti_id=str(runtime_ti.id) ) - mock_supervisor_comms.send.assert_any_call( - SetTaskState(ti_id=runtime_ti.id, key="job_id", value=f"mem://{runtime_ti.id}/job_id") - ) + mock_supervisor_comms.send.assert_any_call(SetTaskState(ti_id=runtime_ti.id, key="job_id", value=ref)) From 6e7a00eacfa60a1474efe04a660c2f89f055c507 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 15 May 2026 18:50:11 +0530 Subject: [PATCH 12/20] cleaning up things --- airflow-core/tests/unit/dag_processing/test_processor.py | 5 ----- airflow-core/tests/unit/jobs/test_triggerer_job.py | 5 ----- task-sdk/src/airflow/sdk/_shared/state | 1 - 3 files changed, 11 deletions(-) delete mode 120000 task-sdk/src/airflow/sdk/_shared/state diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index f6be44e75224d..d2ac085b0b736 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1985,9 +1985,6 @@ def get_type_names(union_type): "DeleteAssetStateByUri", "ClearAssetStateByName", "ClearAssetStateByUri", - "GetAllAssetStateByName", - "GetAllAssetStateByUri", - "GetAllTaskState", } in_task_runner_but_not_in_dag_processing_process = { @@ -2010,8 +2007,6 @@ def get_type_names(union_type): # AIP-103 task/asset state results — worker-only responses to the above messages. "TaskStateResult", "AssetStateResult", - "AllAssetStateResult", - "AllTaskStateResult", } supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index b5c186201ac57..0501783b992d2 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1809,9 +1809,6 @@ def get_type_names(union_type): "DeleteAssetStateByUri", "ClearAssetStateByName", "ClearAssetStateByUri", - "GetAllAssetStateByName", - "GetAllAssetStateByUri", - "GetAllTaskState", } in_task_but_not_in_trigger_runner = { @@ -1836,8 +1833,6 @@ def get_type_names(union_type): # AIP-103 task/asset state results — worker-only responses to the above messages. "TaskStateResult", "AssetStateResult", - "AllAssetStateResult", - "AllTaskStateResult", } supervisor_diff = ( diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state deleted file mode 120000 index cb2f9414b9c9e..0000000000000 --- a/task-sdk/src/airflow/sdk/_shared/state +++ /dev/null @@ -1 +0,0 @@ -/Users/amoghdesai/Documents/OSS/repos/airflow/shared/state/src/airflow_shared/state \ No newline at end of file From 548984480d9c20c6f3d82ec8d2d291a34528b321 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 15 May 2026 19:06:05 +0530 Subject: [PATCH 13/20] cleaning up things --- task-sdk/src/airflow/sdk/_shared/state | 1 + task-sdk/src/airflow/sdk/execution_time/task_runner.py | 2 +- task-sdk/src/airflow/sdk/state.py | 6 +++++- 3 files changed, 7 insertions(+), 2 deletions(-) create mode 120000 task-sdk/src/airflow/sdk/_shared/state diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state new file mode 120000 index 0000000000000..752da6322068e --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/state @@ -0,0 +1 @@ +../../../../../shared/state/src/airflow_shared/state \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 50ad7045f7d3b..2ca443b4d8e86 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -43,7 +43,6 @@ from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk._shared.observability.metrics import stats -from airflow.sdk._shared.state import TaskScope from airflow.sdk._shared.template_rendering import truncate_rendered_value from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( @@ -129,6 +128,7 @@ from airflow.sdk.execution_time.xcom import XCom from airflow.sdk.listener import get_listener_manager from airflow.sdk.observability.metrics import stats_utils +from airflow.sdk.state import TaskScope from airflow.sdk.timezone import coerce_datetime if TYPE_CHECKING: diff --git a/task-sdk/src/airflow/sdk/state.py b/task-sdk/src/airflow/sdk/state.py index 21f9acd54850d..ac2a1126fe4bd 100644 --- a/task-sdk/src/airflow/sdk/state.py +++ b/task-sdk/src/airflow/sdk/state.py @@ -18,4 +18,8 @@ from __future__ import annotations -from airflow.sdk._shared.state import BaseStateBackend as BaseStateBackend +from airflow.sdk._shared.state import ( + AssetScope as AssetScope, + BaseStateBackend as BaseStateBackend, + TaskScope as TaskScope, +) From dd5355194ace6f4ab603d668feee2ef7efe80d67 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Sat, 16 May 2026 20:22:05 +0530 Subject: [PATCH 14/20] fixing tests --- task-sdk/tests/task_sdk/docs/test_public_api.py | 1 + .../tests/task_sdk/execution_time/test_task_runner.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py index 98391927f8a4c..a21424ea101f6 100644 --- a/task-sdk/tests/task_sdk/docs/test_public_api.py +++ b/task-sdk/tests/task_sdk/docs/test_public_api.py @@ -65,6 +65,7 @@ def test_airflow_sdk_no_unexpected_exports(): "providers_manager_runtime", "lineage", "types", + "state", } unexpected = actual - public - ignore assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index c1b15b5e111a2..96c0d591f2800 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -55,6 +55,7 @@ timezone, ) from airflow.sdk._shared.observability.metrics.base_stats_logger import StatsLogger +from airflow.sdk._shared.state import TaskScope from airflow.sdk.api.datamodels._generated import ( AssetProfile, AssetResponse, @@ -1781,7 +1782,9 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ "run_id": "test_run", "task": task, "task_instance": runtime_ti, - "task_state": TaskStateAccessor(ti_id=ti_id), + "task_state": TaskStateAccessor( + ti_id=ti_id, scope=TaskScope(dag_id=dag_id, run_id="test_run", task_id="hello") + ), "ti": runtime_ti, } @@ -1827,7 +1830,10 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s "run_id": "test_run", "task": task, "task_instance": runtime_ti, - "task_state": TaskStateAccessor(ti_id=runtime_ti.id), + "task_state": TaskStateAccessor( + ti_id=runtime_ti.id, + scope=TaskScope(dag_id=runtime_ti.dag_id, run_id="test_run", task_id="hello"), + ), "ti": runtime_ti, "dag_run": dr, "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0), From 6db9bc395942040e1854aba12130872e5f228d87 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 19 May 2026 11:50:00 +0530 Subject: [PATCH 15/20] comments from wei --- shared/state/src/airflow_shared/state/__init__.py | 5 ++++- task-sdk/src/airflow/sdk/execution_time/context.py | 7 ++++--- task-sdk/tests/task_sdk/execution_time/test_context.py | 4 ++-- task-sdk/tests/task_sdk/execution_time/test_task_runner.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 68b664e00eab1..4dc820fa31605 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -193,13 +193,16 @@ def deserialize_task_state_value(self, stored: str) -> str: """ return stored - def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str: + def serialize_asset_state_value(self, *, value: str, key: str, asset_ref: str) -> str: """ Serialize an asset state value before it is sent to the Execution API for db persistence. Called by ``AssetStateAccessor.set()`` on the worker. The return value is what gets stored in the DB — typically a reference path rather than the actual value. Default: return ``value`` unchanged. + + ``asset_ref`` is either the asset name or URI, depending on how the accessor was + constructed. It may be a URI string if the task inlet was declared as ``AssetUriRef``. """ return value diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 8f5c02286b5b5..c5552ff696e7a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -65,6 +65,7 @@ ReceiveMsgType, VariableResult, ) + from airflow.sdk.state import BaseStateBackend from airflow.sdk.types import OutletEventAccessorsProtocol @@ -436,7 +437,7 @@ def get(self, key, default: Any = NOTSET) -> Any: @cache -def _get_worker_state_backend(): +def _get_worker_state_backend() -> BaseStateBackend | None: """ Return the configured worker-side state backend, instantiated once and cached. @@ -590,9 +591,9 @@ def set(self, key: str, value: str) -> None: # if custom backend is configured, store the value on the custom backend, and return the reference # to the stored value to store in the DB backend = _get_worker_state_backend() - asset_name = self._name or self._uri or "" + asset_ref = self._name or self._uri or "" stored = ( - backend.serialize_asset_state_value(value=value, key=key, asset_name=asset_name) + backend.serialize_asset_state_value(value=value, key=key, asset_ref=asset_ref) if backend else value ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index acfc9418f7c8b..2d6c291555950 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -1343,8 +1343,8 @@ def deserialize_task_state_value(self, stored: str) -> str: key = stored.rsplit("/", 1)[-1] return self._actual_key_value_store.get(key, stored) - def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str: - ref = f"mem://{asset_name}/{key}" + def serialize_asset_state_value(self, *, value: str, key: str, asset_ref: str) -> str: + ref = f"mem://{asset_ref}/{key}" self._actual_key_value_store[key] = value self.reference[key] = ref return ref diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 96c0d591f2800..f4070ff71ff3b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5134,7 +5134,7 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) mock_backend.serialize_asset_state_value.assert_called_once_with( - value="2026-05-01", key="watermark", asset_name="my_asset" + value="2026-05-01", key="watermark", asset_ref="my_asset" ) mock_supervisor_comms.send.assert_any_call( SetAssetStateByName(name="my_asset", key="watermark", value="mem://my_asset/watermark") From 167826fd03943f6b29f9208a22d6788f8291b2ad Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 19 May 2026 12:01:40 +0530 Subject: [PATCH 16/20] fixing shared tests --- shared/state/tests/state/test_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index b3e11b2833d3b..bbc344df6789b 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -107,7 +107,7 @@ def deserialize_task_state_value(self, stored): def test_asset_state_serialize_deserialize_round_trip(self, backend): original = "2026-05-01" serialized = backend.serialize_asset_state_value( - value="2026-05-01", key="watermark", asset_name="my_asset" + value="2026-05-01", key="watermark", asset_ref="my_asset" ) deserialized = backend.deserialize_asset_state_value(serialized) assert deserialized == original @@ -131,7 +131,7 @@ def deserialize_asset_state_value(self, stored): b = MyBackend() assert ( - b.serialize_asset_state_value(value="2026-05-01", key="watermark", asset_name="my_asset") + b.serialize_asset_state_value(value="2026-05-01", key="watermark", asset_ref="my_asset") == "s3://bucket/assets/my_asset/watermark" ) assert ( From 99e4c3b29ffeb02519ada85b932f33e3307f4f13 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 19 May 2026 12:20:41 +0530 Subject: [PATCH 17/20] uncommenting old comments and cleaning up --- .../src/airflow_shared/state/__init__.py | 9 ++++++-- shared/state/tests/state/test_state.py | 4 ++-- .../src/airflow/sdk/execution_time/context.py | 7 +------ .../airflow/sdk/execution_time/task_runner.py | 13 +++++------- .../execution_time/test_task_runner.py | 21 +++++++++++++++++++ 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 0682f76983fec..6c461882bc90a 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -50,8 +50,13 @@ class AssetScope: """ Identifies the state namespace for an asset. - We need to store ``name`` or ``uri`` since workers do not have access to the integer ``asset_id``. - (Access method is through AssetStateAccessor). + Server-side backends receive ``asset_id``. Worker-side backends receive ``name`` or ``uri`` + since workers do not have access to the integer ``asset_id``. + + Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if an asset is + deactivated and a new one created with the same name, both share the same ``name`` value. + State for inactive assets is cleaned up by the orphan GC pass; until then, stale rows exist + in the DB but cannot be written to (the Execution API resolver filters to active assets only). """ asset_id: int | None = None diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index bbc344df6789b..3cfbdc61be7c2 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -123,8 +123,8 @@ async def aset(self, scope, key, value): ... async def adelete(self, scope, key): ... async def aclear(self, scope, *, all_map_indices=False): ... - def serialize_asset_state_value(self, *, value, key, asset_name): - return f"s3://bucket/assets/{asset_name}/{key}" + def serialize_asset_state_value(self, *, value, key, asset_ref): + return f"s3://bucket/assets/{asset_ref}/{key}" def deserialize_asset_state_value(self, stored): return f"resolved:{stored}" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index e568d230d718a..a2c0a2cc9a0d6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -444,12 +444,7 @@ def get(self, key, default: Any = NOTSET) -> Any: @cache def _get_worker_state_backend() -> BaseStateBackend | None: - """ - Return the configured worker-side state backend, instantiated once and cached. - - # TODO: rebase / include https://github.com/apache/airflow/pull/66699 once merged - # to also forward ``retention_days`` through the comms layer. - """ + """Return the configured worker-side state backend, instantiated once and cached.""" from airflow.sdk.configuration import get_state_backend return get_state_backend() diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d4625ea122ca5..35b4e6267e92c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -119,6 +119,7 @@ TaskStateAccessor, TriggeringAssetEventsAccessor, VariableAccessor, + _get_worker_state_backend, context_get_outlet_events, context_to_airflow_vars, get_previous_dagrun_success, @@ -1458,20 +1459,16 @@ def _handle_current_task_success( stats.incr("operator_successes", tags={**stats_tags, "operator_name": operator}) stats.incr("ti_successes", tags=stats_tags) - # TODO: uncomment below once https://github.com/apache/airflow/pull/66699 is merged - # if conf.getboolean("state_store", "clear_on_success"): - # log.info("Task state will be cleared by the server because clear_on_success is enabled.") - # - # if _get_worker_state_backend() is not None: - # # clear the task state keys for custom state backends configured on worker side - # context["task_state"].clear() - task_outlets = list(_build_asset_profiles(ti.task.outlets)) outlet_events = list(_serialize_outlet_events(context["outlet_events"])) if conf.getboolean("state_store", "clear_on_success"): log.info("Task state will be cleared by the server because clear_on_success is enabled.") + if _get_worker_state_backend() is not None: + # clear the task state keys for custom state backends configured on worker side + context["task_state"].clear() + msg = SucceedTask( end_date=end_date, task_outlets=task_outlets, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 43eb9cf47dad8..22a29dc9dba8d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5269,3 +5269,24 @@ def execute(self, context): value="app_001", key="job_id", ti_id=str(runtime_ti.id) ) mock_supervisor_comms.send.assert_any_call(SetTaskState(ti_id=runtime_ti.id, key="job_id", value=ref)) + + @conf_vars({("state_store", "clear_on_success"): "True"}) + def test_clear_on_success_calls_clear_when_worker_backend_configured( + self, create_runtime_ti, mock_supervisor_comms + ): + """When clear_on_success=True and a worker backend is configured, clear() is called on task success.""" + mock_backend = mock.MagicMock() + + class MyOperator(BaseOperator): + def execute(self, context): + pass + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + with mock.patch( + "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend + ): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=runtime_ti.id, all_map_indices=False)) From b78f33847a4c5a50f88744913648b9443caead23 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 19 May 2026 14:52:30 +0530 Subject: [PATCH 18/20] comments from jason --- .../src/airflow_shared/state/__init__.py | 18 +++++++++--- task-sdk/src/airflow/sdk/configuration.py | 15 ---------- .../src/airflow/sdk/execution_time/context.py | 28 +++++++++++++------ .../task_sdk/execution_time/test_context.py | 8 +++--- .../execution_time/test_task_runner.py | 8 +++--- 5 files changed, 41 insertions(+), 36 deletions(-) diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 6c461882bc90a..5e51136d2f3c8 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -199,17 +199,22 @@ def cleanup(self) -> None: ``[state_store] default_retention_days``) and deciding what to delete. """ - def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str: + def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: """ Serialize a task state value before it is sent to the execution API for db persistence. Called by ``TaskStateAccessor.set()`` on the worker. The return value is what gets stored in the DB — typically a reference path (e.g. an S3 key) rather than the actual value. Default: return ``value`` unchanged. + + The returned reference must be deterministic — given the same ``ti_id`` and ``key`` it + must always return the same string. Do not use timestamps or random UUIDs as part of + the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external + object will be orphaned. """ return value - def deserialize_task_state_value(self, stored: str) -> str: + def deserialize_task_state_from_ref(self, stored: str) -> str: """ Resolve a stored task state string back to the actual value. @@ -218,7 +223,7 @@ def deserialize_task_state_value(self, stored: str) -> str: """ return stored - def serialize_asset_state_value(self, *, value: str, key: str, asset_ref: str) -> str: + def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: """ Serialize an asset state value before it is sent to the Execution API for db persistence. @@ -228,10 +233,15 @@ def serialize_asset_state_value(self, *, value: str, key: str, asset_ref: str) - ``asset_ref`` is either the asset name or URI, depending on how the accessor was constructed. It may be a URI string if the task inlet was declared as ``AssetUriRef``. + + The returned reference must be deterministic — given the same ``asset_ref`` and ``key`` it + must always return the same string. Do not use timestamps or random UUIDs as part of + the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external + object will be orphaned. """ return value - def deserialize_asset_state_value(self, stored: str) -> str: + def deserialize_asset_state_from_ref(self, stored: str) -> str: """ Resolve a stored asset state string back to the actual value. diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py index d7b3e2c926aba..fb32f990c5880 100644 --- a/task-sdk/src/airflow/sdk/configuration.py +++ b/task-sdk/src/airflow/sdk/configuration.py @@ -211,21 +211,6 @@ def remove_all_read_configurations(self): self.remove_section(section) -def get_state_backend(): - """ - Get the state backend if configured via ``[workers] state_backend``. - - Returns the instantiated backend, or ``None`` if not configured. - """ - # Lazy import to trigger __getattr__ and lazy initialization - from airflow.sdk.configuration import conf - - class_name = conf.get("workers", "state_backend", fallback="") - if not class_name: - return None - return import_string(class_name)() - - def get_custom_secret_backend(worker_mode: bool = False): """ Get Secret Backend if defined in airflow.cfg. diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index a2c0a2cc9a0d6..db8cc5e0faca4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -445,9 +445,18 @@ def get(self, key, default: Any = NOTSET) -> Any: @cache def _get_worker_state_backend() -> BaseStateBackend | None: """Return the configured worker-side state backend, instantiated once and cached.""" - from airflow.sdk.configuration import get_state_backend + class_name = conf.get("workers", "state_backend", fallback="") + if not class_name: + return None + from airflow.sdk._shared.module_loading import import_string - return get_state_backend() + try: + return import_string(class_name)() + except (ImportError, AttributeError) as e: + raise ValueError( + f"Could not load worker state backend {class_name!r}. " + f"Check the [workers] state_backend config value. Error: {e}" + ) from e class TaskStateAccessor: @@ -485,7 +494,7 @@ def get(self, key: str) -> str | None: # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from # custom backend using the reference backend = _get_worker_state_backend() - return backend.deserialize_task_state_value(stored) if backend else stored + return backend.deserialize_task_state_from_ref(stored) if backend else stored return None def set(self, key: str, value: str, *, retention: timedelta | None = None) -> None: @@ -515,7 +524,7 @@ def set(self, key: str, value: str, *, retention: timedelta | None = None) -> No # to the stored value to store in the DB backend = _get_worker_state_backend() stored = ( - backend.serialize_task_state_value(value=value, key=key, ti_id=str(self._ti_id)) + backend.serialize_task_state_to_ref(value=value, key=key, ti_id=str(self._ti_id)) if backend else value ) @@ -600,7 +609,7 @@ def get(self, key: str) -> str | None: # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from # custom backend using the reference backend = _get_worker_state_backend() - return backend.deserialize_asset_state_value(stored) if backend else stored + return backend.deserialize_asset_state_from_ref(stored) if backend else stored return None def set(self, key: str, value: str) -> None: @@ -613,7 +622,7 @@ def set(self, key: str, value: str) -> None: backend = _get_worker_state_backend() asset_ref = self._name or self._uri or "" stored = ( - backend.serialize_asset_state_value(value=value, key=key, asset_ref=asset_ref) + backend.serialize_asset_state_to_ref(value=value, key=key, asset_ref=asset_ref) if backend else value ) @@ -636,8 +645,8 @@ def delete(self, key: str) -> None: from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS backend = _get_worker_state_backend() - # session=None signals worker-side: backend cleans up external storage only. - # DB reference is removed separately via comms below. + # custom backends handle external storage cleanup only; + # DB reference is removed by the comms call below. if backend is not None: backend.delete(AssetScope(name=self._name, uri=self._uri), key) @@ -659,7 +668,8 @@ def clear(self) -> None: from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS backend = _get_worker_state_backend() - # session=None signals worker-side: backend cleans up external storage only. + # custom backends handle external storage cleanup only; + # DB references are cleared by the comms call below. # DB references are cleared separately via comms below. if backend is not None: backend.clear(AssetScope(name=self._name, uri=self._uri)) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index ae48d056966c7..6537bee212dfb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -1386,23 +1386,23 @@ def __init__(self): self._actual_key_value_store: dict[str, str] = {} # key -> actual value self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI) - def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str: + def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: ref = f"mem://{ti_id}/{key}" self._actual_key_value_store[key] = value self.reference[key] = ref return ref - def deserialize_task_state_value(self, stored: str) -> str: + def deserialize_task_state_from_ref(self, stored: str) -> str: key = stored.rsplit("/", 1)[-1] return self._actual_key_value_store.get(key, stored) - def serialize_asset_state_value(self, *, value: str, key: str, asset_ref: str) -> str: + def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: ref = f"mem://{asset_ref}/{key}" self._actual_key_value_store[key] = value self.reference[key] = ref return ref - def deserialize_asset_state_value(self, stored: str) -> str: + def deserialize_asset_state_from_ref(self, stored: str) -> str: key = stored.rsplit("/", 1)[-1] return self._actual_key_value_store.get(key, stored) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 22a29dc9dba8d..91571674f1251 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5229,14 +5229,14 @@ def execute(self, context): mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect mock_backend = mock.MagicMock() - mock_backend.serialize_asset_state_value.return_value = "mem://my_asset/watermark" + mock_backend.serialize_asset_state_to_ref.return_value = "mem://my_asset/watermark" with mock.patch( "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend ): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_backend.serialize_asset_state_value.assert_called_once_with( + mock_backend.serialize_asset_state_to_ref.assert_called_once_with( value="2026-05-01", key="watermark", asset_ref="my_asset" ) mock_supervisor_comms.send.assert_any_call( @@ -5258,14 +5258,14 @@ def execute(self, context): mock_backend = mock.MagicMock() ref = f"mem://{runtime_ti.id}/job_id" - mock_backend.serialize_task_state_value.return_value = ref + mock_backend.serialize_task_state_to_ref.return_value = ref with mock.patch( "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend ): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_backend.serialize_task_state_value.assert_called_once_with( + mock_backend.serialize_task_state_to_ref.assert_called_once_with( value="app_001", key="job_id", ti_id=str(runtime_ti.id) ) mock_supervisor_comms.send.assert_any_call(SetTaskState(ti_id=runtime_ti.id, key="job_id", value=ref)) From 4ce69093577c34f92e1cdc34b74db061f3e23c16 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 19 May 2026 17:44:21 +0530 Subject: [PATCH 19/20] fixing tests --- shared/state/tests/state/test_state.py | 8 ++++---- .../task_sdk/execution_time/test_context.py | 17 ++++++++++++----- .../task_sdk/execution_time/test_task_runner.py | 12 +++++++++--- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index 3cfbdc61be7c2..489ba185093cd 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -73,8 +73,8 @@ def test_abstract_methods_cover_full_interface(self): def test_task_state_serialize_deserialize_round_trip(self, backend): original = "app_1234" - serialized = backend.serialize_task_state_value(value=original, key="job_id", ti_id="abc-123") - deserialized = backend.deserialize_task_state_value(serialized) + serialized = backend.serialize_task_state_to_ref(value=original, key="job_id", ti_id="abc-123") + deserialized = backend.deserialize_task_state_from_ref(serialized) assert deserialized == original def test_custom_backend_overrides_task_state_ser_deser(self): @@ -106,10 +106,10 @@ def deserialize_task_state_value(self, stored): def test_asset_state_serialize_deserialize_round_trip(self, backend): original = "2026-05-01" - serialized = backend.serialize_asset_state_value( + serialized = backend.serialize_asset_state_to_ref( value="2026-05-01", key="watermark", asset_ref="my_asset" ) - deserialized = backend.deserialize_asset_state_value(serialized) + deserialized = backend.deserialize_asset_state_from_ref(serialized) assert deserialized == original def test_custom_backend_overrides_asset_state_ser_deser(self): diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 6537bee212dfb..285f1fd9763e9 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -1121,7 +1121,9 @@ def test_set_with_retention_computes_expires_at(self, mock_supervisor_comms, tim now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc) time_machine.move_to(now, tick=False) - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", retention=timedelta(days=7)) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set( + "job_id", "app_001", retention=timedelta(days=7) + ) mock_supervisor_comms.send.assert_called_once_with( SetTaskState( @@ -1137,7 +1139,7 @@ def test_set_with_never_expire_sends_null_expires_at(self, mock_supervisor_comms mock_supervisor_comms.send.return_value = OKResponse(ok=True) - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", retention=NEVER_EXPIRE) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001", retention=NEVER_EXPIRE) mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", expires_at=None) @@ -1148,7 +1150,7 @@ def test_set_global_default_zero_sends_null_expires_at(self, mock_supervisor_com mock_supervisor_comms.send.return_value = OKResponse(ok=True) with conf_vars({("state_store", "default_retention_days"): "0"}): - TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") mock_supervisor_comms.send.assert_called_once_with( SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", expires_at=None) @@ -1436,15 +1438,20 @@ def backend(self): ): yield b - def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend): + def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend, time_machine): """set() stores actual value in backend and sends mem:// reference via comms.""" mock_supervisor_comms.send.return_value = OKResponse(ok=True) expected_ref = f"mem://{self.TI_ID}/job_id" + frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc) + time_machine.move_to(frozen_dt, tick=False) + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001") # comms message has the mem:// reference, not the actual value mock_supervisor_comms.send.assert_called_once_with( - SetTaskState(ti_id=self.TI_ID, key="job_id", value=expected_ref) + SetTaskState( + ti_id=self.TI_ID, key="job_id", value=expected_ref, expires_at=frozen_dt + timedelta(days=30) + ) ) # actual value is stored on the backend, reference is stored for DB assert backend._actual_key_value_store["job_id"] == "app_001" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 91571674f1251..0d3b0df6abd84 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5244,7 +5244,7 @@ def execute(self, context): ) def test_task_state_set_sends_reference_via_custom_backend( - self, create_runtime_ti, mock_supervisor_comms + self, create_runtime_ti, mock_supervisor_comms, time_machine ): """When a worker backend is configured, task state set() sends a reference, not the actual value.""" @@ -5252,6 +5252,8 @@ class MyOperator(BaseOperator): def execute(self, context): context["task_state"].set("job_id", "app_001") + frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc) + time_machine.move_to(frozen_dt, tick=False) task = MyOperator(task_id="t") runtime_ti = create_runtime_ti(task=task) mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect @@ -5268,7 +5270,11 @@ def execute(self, context): mock_backend.serialize_task_state_to_ref.assert_called_once_with( value="app_001", key="job_id", ti_id=str(runtime_ti.id) ) - mock_supervisor_comms.send.assert_any_call(SetTaskState(ti_id=runtime_ti.id, key="job_id", value=ref)) + mock_supervisor_comms.send.assert_any_call( + SetTaskState( + ti_id=runtime_ti.id, key="job_id", value=ref, expires_at=frozen_dt + timedelta(days=30) + ) + ) @conf_vars({("state_store", "clear_on_success"): "True"}) def test_clear_on_success_calls_clear_when_worker_backend_configured( @@ -5285,7 +5291,7 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task) with mock.patch( - "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend + "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend ): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) From 59e33f7b39158ea08dd5359cfffe83ee16645748 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 20 May 2026 12:22:59 +0530 Subject: [PATCH 20/20] handling review comments from kaxil --- .../src/airflow/config_templates/config.yml | 5 ++- .../src/airflow_shared/state/__init__.py | 4 ++ shared/state/tests/state/test_state.py | 39 ++++++++++++------ .../src/airflow/sdk/execution_time/context.py | 40 ++++++++++++------- .../airflow/sdk/execution_time/task_runner.py | 5 +-- .../execution_time/test_task_runner.py | 13 ++++-- 6 files changed, 68 insertions(+), 38 deletions(-) diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index a5033fe20dc30..0712b6746115f 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1891,8 +1891,9 @@ workers: default: "" state_backend: description: | - Full class name of the state backend to use on workers for direct task state access, - bypassing the execution API. + Full class name of a custom worker-side state backend. When set, task state values are + routed through this backend so large payloads or credentialed storage stay on worker + infrastructure. The Execution API still records a reference string in the database. Leave empty (default) to use the standard path through the task sdk supervisor. version_added: 3.3.0 diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 5e51136d2f3c8..7aa9fcba8372d 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -63,6 +63,10 @@ class AssetScope: name: str | None = None uri: str | None = None + def __post_init__(self) -> None: + if self.asset_id is None and self.name is None and self.uri is None: + raise ValueError("AssetScope requires at least one of: asset_id, name, or uri") + StateScope = TaskScope | AssetScope diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index 489ba185093cd..1ea31194e2788 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -18,7 +18,22 @@ import pytest -from airflow_shared.state import BaseStateBackend, StateScope +from airflow_shared.state import AssetScope, BaseStateBackend, StateScope + + +class TestAssetScope: + def test_requires_at_least_one_identifier(self): + with pytest.raises(ValueError, match="at least one of"): + AssetScope() + + def test_asset_id_alone_is_valid(self): + AssetScope(asset_id=1) + + def test_name_alone_is_valid(self): + AssetScope(name="my_asset") + + def test_uri_alone_is_valid(self): + AssetScope(uri="s3://bucket/key") class TestBaseStateBackend: @@ -88,19 +103,18 @@ async def aset(self, scope, key, value): ... async def adelete(self, scope, key): ... async def aclear(self, scope, *, all_map_indices=False): ... - def serialize_task_state_value(self, *, value, key, ti_id): + def serialize_task_state_to_ref(self, *, value, key, ti_id): return f"s3://bucket/{ti_id}/{key}" - def deserialize_task_state_value(self, stored): + def deserialize_task_state_from_ref(self, stored): return f"fetched:{stored}" b = MyBackend() - assert ( - b.serialize_task_state_value(value="app_1234", key="job_id", ti_id="abc-123") - == "s3://bucket/abc-123/job_id" + assert b.serialize_task_state_to_ref(value="app_1234", key="job_id", ti_id="abc-123") == ( + "s3://bucket/abc-123/job_id" ) assert ( - b.deserialize_task_state_value("s3://bucket/abc-123/job_id") + b.deserialize_task_state_from_ref("s3://bucket/abc-123/job_id") == "fetched:s3://bucket/abc-123/job_id" ) @@ -123,18 +137,17 @@ async def aset(self, scope, key, value): ... async def adelete(self, scope, key): ... async def aclear(self, scope, *, all_map_indices=False): ... - def serialize_asset_state_value(self, *, value, key, asset_ref): + def serialize_asset_state_to_ref(self, *, value, key, asset_ref): return f"s3://bucket/assets/{asset_ref}/{key}" - def deserialize_asset_state_value(self, stored): + def deserialize_asset_state_from_ref(self, stored): return f"resolved:{stored}" b = MyBackend() - assert ( - b.serialize_asset_state_value(value="2026-05-01", key="watermark", asset_ref="my_asset") - == "s3://bucket/assets/my_asset/watermark" + assert b.serialize_asset_state_to_ref(value="2026-05-01", key="watermark", asset_ref="my_asset") == ( + "s3://bucket/assets/my_asset/watermark" ) assert ( - b.deserialize_asset_state_value("s3://bucket/assets/my_asset/watermark") + b.deserialize_asset_state_from_ref("s3://bucket/assets/my_asset/watermark") == "resolved:s3://bucket/assets/my_asset/watermark" ) diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index db8cc5e0faca4..1d5b0694667a0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -536,10 +536,12 @@ def delete(self, key: str) -> None: from airflow.sdk.execution_time.comms import DeleteTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # cleanup the DB ref first, if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). + SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) backend = _get_worker_state_backend() if backend is not None: backend.delete(self._scope, key) - SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) def clear(self, all_map_indices: bool = False) -> None: """ @@ -552,10 +554,23 @@ def clear(self, all_map_indices: bool = False) -> None: from airflow.sdk.execution_time.comms import ClearTaskState from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + # cleanup the DB ref first, if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). + SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) backend = _get_worker_state_backend() if backend is not None: backend.clear(self._scope, all_map_indices=all_map_indices) - SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) + + def _clear_backend_only(self) -> None: + """ + Clear external storage via the worker backend without sending a comms message. + + Used by clear_on_success: the server already clears DB rows as part of SucceedTask, + so the comms round-trip is redundant. + """ + backend = _get_worker_state_backend() + if backend is not None: + backend.clear(self._scope) class AssetStateAccessor: @@ -644,18 +659,17 @@ def delete(self, key: str) -> None: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - backend = _get_worker_state_backend() - # custom backends handle external storage cleanup only; - # DB reference is removed by the comms call below. - if backend is not None: - backend.delete(AssetScope(name=self._name, uri=self._uri), key) - msg: ToSupervisor if self._name: msg = DeleteAssetStateByName(name=self._name, key=key) elif self._uri: msg = DeleteAssetStateByUri(uri=self._uri, key=key) + # DB ref first: if backend cleanup fails after this, the ref is gone and + # deterministic keys are recoverable on next set(). SUPERVISOR_COMMS.send(msg) + backend = _get_worker_state_backend() + if backend is not None: + backend.delete(AssetScope(name=self._name, uri=self._uri), key) def clear(self) -> None: """Delete all state keys for this asset.""" @@ -667,19 +681,15 @@ def clear(self) -> None: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - backend = _get_worker_state_backend() - # custom backends handle external storage cleanup only; - # DB references are cleared by the comms call below. - # DB references are cleared separately via comms below. - if backend is not None: - backend.clear(AssetScope(name=self._name, uri=self._uri)) - msg: ToSupervisor if self._name: msg = ClearAssetStateByName(name=self._name) elif self._uri: msg = ClearAssetStateByUri(uri=self._uri) SUPERVISOR_COMMS.send(msg) + backend = _get_worker_state_backend() + if backend is not None: + backend.clear(AssetScope(name=self._name, uri=self._uri)) class AssetStateAccessors: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 35b4e6267e92c..761ce3714f564 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -119,7 +119,6 @@ TaskStateAccessor, TriggeringAssetEventsAccessor, VariableAccessor, - _get_worker_state_backend, context_get_outlet_events, context_to_airflow_vars, get_previous_dagrun_success, @@ -1465,9 +1464,7 @@ def _handle_current_task_success( if conf.getboolean("state_store", "clear_on_success"): log.info("Task state will be cleared by the server because clear_on_success is enabled.") - if _get_worker_state_backend() is not None: - # clear the task state keys for custom state backends configured on worker side - context["task_state"].clear() + context["task_state"]._clear_backend_only() msg = SucceedTask( end_date=end_date, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 0d3b0df6abd84..1558e90ae5e93 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5277,10 +5277,10 @@ def execute(self, context): ) @conf_vars({("state_store", "clear_on_success"): "True"}) - def test_clear_on_success_calls_clear_when_worker_backend_configured( + def test_clear_on_success_clears_backend_without_comms_roundtrip( self, create_runtime_ti, mock_supervisor_comms ): - """When clear_on_success=True and a worker backend is configured, clear() is called on task success.""" + """clear_on_success calls backend.clear() directly without sending ClearTaskState comms.""" mock_backend = mock.MagicMock() class MyOperator(BaseOperator): @@ -5291,8 +5291,13 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task) with mock.patch( - "airflow.sdk.execution_time.task_runner._get_worker_state_backend", return_value=mock_backend + "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend ): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=runtime_ti.id, all_map_indices=False)) + mock_backend.clear.assert_called_once() + sent_types = [ + type(call.kwargs.get("msg") or (call.args[0] if call.args else None)) + for call in mock_supervisor_comms.send.call_args_list + ] + assert ClearTaskState not in sent_types