Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,17 @@ workers:
sensitive: true
example: ~
default: ""
state_backend:
description: |
Full class name of a custom worker-side state backend. When set, task state values are
routed through this backend so large payloads or credentialed storage stay on worker
infrastructure. The Execution API still records a reference string in the database.

Leave empty (default) to use the standard path through the Task SDK supervisor.
version_added: 3.3.0
type: string
example: "mypackage.state.S3StateBackend"
default: ""
min_heartbeat_interval:
description: |
The minimum interval (in seconds) at which the worker checks the task instance's
Expand Down
71 changes: 69 additions & 2 deletions shared/state/src/airflow_shared/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,25 @@ class TaskScope:

@dataclass(frozen=True)
class AssetScope:
"""Identifies the state namespace for an asset."""
"""
Identifies the state namespace for an asset.

Server-side backends receive ``asset_id``. Worker-side backends receive ``name`` or ``uri``
since workers do not have access to the integer ``asset_id``.

Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if an asset is
deactivated and a new one created with the same name, both share the same ``name`` value.
State for inactive assets is cleaned up by the orphan GC pass; until then, stale rows exist
in the DB but cannot be written to (the Execution API resolver filters to active assets only).
"""

asset_id: int | None = None
Comment thread
amoghrajesh marked this conversation as resolved.
name: str | None = None
uri: str | None = None

asset_id: int
def __post_init__(self) -> None:
if self.asset_id is None and self.name is None and self.uri is None:
raise ValueError("AssetScope requires at least one of: asset_id, name, or uri")


StateScope = TaskScope | AssetScope
Expand Down Expand Up @@ -186,3 +202,54 @@ def cleanup(self) -> None:
retention policy. The backend is responsible for reading any relevant config (e.g.
``[state_store] default_retention_days``) and deciding what to delete.
"""

def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str:
    """
    Turn a task state value into the string the Execution API persists in the DB.

    ``TaskStateAccessor.set()`` invokes this hook on the worker. Whatever it returns is
    what ends up stored in the DB — usually a reference path (such as an S3 key) in place
    of the real value. The base implementation hands ``value`` back untouched.

    Overrides must be deterministic: the same ``ti_id`` and ``key`` always have to yield
    the same reference. Keep timestamps and random UUIDs out of it, because
    ``delete()``/``clear()`` could then not rebuild the reference and the external object
    would be orphaned.
    """
    return value

def deserialize_task_state_from_ref(self, stored: str) -> str:
    """
    Map a persisted task state string back to the real value.

    ``TaskStateAccessor.get()`` invokes this hook once the stored string has been fetched
    from the Execution API. The base implementation hands ``stored`` back untouched.
    """
    return stored

def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str:
    """
    Turn an asset state value into the string the Execution API persists in the DB.

    ``AssetStateAccessor.set()`` invokes this hook on the worker. Whatever it returns is
    what ends up stored in the DB — usually a reference path in place of the real value.
    The base implementation hands ``value`` back untouched.

    ``asset_ref`` carries either the asset name or its URI, depending on how the accessor
    was built; it may be a URI string when the task inlet was declared as ``AssetUriRef``.

    Overrides must be deterministic: the same ``asset_ref`` and ``key`` always have to
    yield the same reference. Keep timestamps and random UUIDs out of it, because
    ``delete()``/``clear()`` could then not rebuild the reference and the external object
    would be orphaned.
    """
    return value

def deserialize_asset_state_from_ref(self, stored: str) -> str:
    """
    Map a persisted asset state string back to the real value.

    ``AssetStateAccessor.get()`` invokes this hook once the stored string has been fetched
    from the Execution API. The base implementation hands ``stored`` back untouched.
    """
    return stored
83 changes: 82 additions & 1 deletion shared/state/tests/state/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,22 @@

import pytest

from airflow_shared.state import BaseStateBackend, StateScope
from airflow_shared.state import AssetScope, BaseStateBackend, StateScope


class TestAssetScope:
    """Construction rules for ``AssetScope``: at least one identifier has to be supplied."""

    def test_requires_at_least_one_identifier(self):
        """A scope built without any identifier is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="at least one of"):
            AssetScope()

    def test_asset_id_alone_is_valid(self):
        """The integer ``asset_id`` is sufficient on its own."""
        AssetScope(asset_id=1)

    def test_name_alone_is_valid(self):
        """The asset ``name`` is sufficient on its own."""
        AssetScope(name="my_asset")

    def test_uri_alone_is_valid(self):
        """The asset ``uri`` is sufficient on its own."""
        AssetScope(uri="s3://bucket/key")


class TestBaseStateBackend:
Expand Down Expand Up @@ -70,3 +85,69 @@ def test_abstract_methods_cover_full_interface(self):
"""BaseStateBackend enforces all 8 sync+async methods as abstract."""
expected = {"get", "set", "delete", "clear", "aget", "aset", "adelete", "aclear"}
assert BaseStateBackend.__abstractmethods__ == expected

def test_task_state_serialize_deserialize_round_trip(self, backend):
    """Serializing a task state value and resolving the result yields the input again."""
    payload = "app_1234"
    ref = backend.serialize_task_state_to_ref(value=payload, key="job_id", ti_id="abc-123")
    assert backend.deserialize_task_state_from_ref(ref) == payload

def test_custom_backend_overrides_task_state_ser_deser(self):
    """A subclass can replace both task-state hooks, and its overrides are the ones that run."""

    class MyBackend(BaseStateBackend):
        # The abstract storage API is stubbed out; only the two ref hooks matter here.
        def get(self, scope, key): ...
        def set(self, scope, key, value): ...
        def delete(self, scope, key): ...
        def clear(self, scope, *, all_map_indices=False): ...
        async def aget(self, scope, key): ...
        async def aset(self, scope, key, value): ...
        async def adelete(self, scope, key): ...
        async def aclear(self, scope, *, all_map_indices=False): ...

        def serialize_task_state_to_ref(self, *, value, key, ti_id):
            return f"s3://bucket/{ti_id}/{key}"

        def deserialize_task_state_from_ref(self, stored):
            return f"fetched:{stored}"

    custom = MyBackend()

    ref = custom.serialize_task_state_to_ref(value="app_1234", key="job_id", ti_id="abc-123")
    assert ref == "s3://bucket/abc-123/job_id"

    resolved = custom.deserialize_task_state_from_ref("s3://bucket/abc-123/job_id")
    assert resolved == "fetched:s3://bucket/abc-123/job_id"

def test_asset_state_serialize_deserialize_round_trip(self, backend):
    """Serializing an asset state value and resolving the result yields the input again."""
    original = "2026-05-01"
    # Pass ``original`` itself (not a repeated literal) so the value that is serialized and
    # the value the result is compared against cannot drift apart.
    serialized = backend.serialize_asset_state_to_ref(
        value=original, key="watermark", asset_ref="my_asset"
    )
    deserialized = backend.deserialize_asset_state_from_ref(serialized)
    assert deserialized == original

def test_custom_backend_overrides_asset_state_ser_deser(self):
    """A subclass can replace both asset-state hooks, and its overrides are the ones that run."""

    class MyBackend(BaseStateBackend):
        # The abstract storage API is stubbed out; only the two ref hooks matter here.
        def get(self, scope, key): ...
        def set(self, scope, key, value): ...
        def delete(self, scope, key): ...
        def clear(self, scope, *, all_map_indices=False): ...
        async def aget(self, scope, key): ...
        async def aset(self, scope, key, value): ...
        async def adelete(self, scope, key): ...
        async def aclear(self, scope, *, all_map_indices=False): ...

        def serialize_asset_state_to_ref(self, *, value, key, asset_ref):
            return f"s3://bucket/assets/{asset_ref}/{key}"

        def deserialize_asset_state_from_ref(self, stored):
            return f"resolved:{stored}"

    b = MyBackend()
    assert b.serialize_asset_state_to_ref(value="2026-05-01", key="watermark", asset_ref="my_asset") == (
        "s3://bucket/assets/my_asset/watermark"
    )
    assert (
        b.deserialize_asset_state_from_ref("s3://bucket/assets/my_asset/watermark")
        == "resolved:s3://bucket/assets/my_asset/watermark"
    )
3 changes: 3 additions & 0 deletions task-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ path = "src/airflow/sdk/__init__.py"
"../shared/listeners/src/airflow_shared/listeners" = "src/airflow/sdk/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager"
"../shared/providers_discovery/src/airflow_shared/providers_discovery" = "src/airflow/sdk/_shared/providers_discovery"
"../shared/state/src/airflow_shared/state" = "src/airflow/sdk/_shared/state"
"../shared/template_rendering/src/airflow_shared/template_rendering" = "src/airflow/sdk/_shared/template_rendering"

[tool.hatch.build.targets.wheel]
Expand Down Expand Up @@ -240,6 +241,7 @@ apache-airflow = {workspace = true}
apache-airflow-devel-common = {workspace = true}
apache-airflow-providers-common-sql = {workspace = true}
apache-airflow-providers-standard = {workspace = true}
apache-airflow-shared-state = {workspace = true}

# To use:
#
Expand Down Expand Up @@ -316,6 +318,7 @@ shared_distributions = [
"apache-airflow-shared-secrets-backend",
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-serialization",
"apache-airflow-shared-state",
"apache-airflow-shared-timezones",
"apache-airflow-shared-observability",
"apache-airflow-shared-plugins-manager",
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/_shared/state
3 changes: 1 addition & 2 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
OKResponse,
PreviousDagRunResult,
PreviousTIResult,
RescheduleTask,
SkipDownstreamTasks,
TaskRescheduleStartDate,
TICount,
Expand All @@ -104,8 +105,6 @@
from datetime import datetime
from typing import ParamSpec

from airflow.sdk.execution_time.comms import RescheduleTask

P = ParamSpec("P")
T = TypeVar("T")

Expand Down
3 changes: 1 addition & 2 deletions task-sdk/src/airflow/sdk/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
configure_parser_from_configuration_description,
expand_env_var,
)
from airflow.sdk._shared.module_loading import import_string
from airflow.sdk.execution_time.secrets import _SERVER_DEFAULT_SECRETS_SEARCH_PATH

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -236,8 +237,6 @@ def initialize_secrets_backends(

Uses SDK's conf instead of Core's conf.
"""
from airflow.sdk._shared.module_loading import import_string

backend_list = []
worker_mode = False
# Determine worker mode - if default_backends is not the server default, it's worker mode
Expand Down
Loading
Loading