From d4a9ef91cf2585d0aeada0a21fdc2c2858caca18 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 10 Mar 2025 15:43:00 +0800 Subject: [PATCH 01/10] Rewrite asset event registration There are many issues in the original implementation when you send something that's not a simple asset, so I ended up rewriting almost the whole thing. Now, the task runner collects everything the user writes into outlet_events, and sends all of them, along with the task's outlets as declared by the user, verbatim to the API server. The server does all the resolution and filtering instead. --- .../execution_api/datamodels/taskinstance.py | 2 +- airflow/models/taskinstance.py | 180 +++++++++--------- .../airflow/sdk/api/datamodels/_generated.py | 2 +- .../airflow/sdk/definitions/asset/__init__.py | 2 +- .../src/airflow/sdk/execution_time/context.py | 4 +- .../airflow/sdk/execution_time/task_runner.py | 47 ++--- task-sdk/src/airflow/sdk/types.py | 8 +- 7 files changed, 121 insertions(+), 124 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index f6e586347d36b..97daf54b0c5e0 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -97,7 +97,7 @@ class TISuccessStatePayload(StrictBaseModel): """When the task completed executing""" task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)] - outlet_events: Annotated[list[Any], Field(default_factory=list)] + outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)] class TITargetStatePayload(StrictBaseModel): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 94e573f94abd7..7ffacc12bdc9f 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -80,7 +80,6 @@ from airflow.exceptions import ( AirflowException, AirflowFailException, -
AirflowInactiveAssetAddedToAssetAliasException, AirflowInactiveAssetInInletOrOutletException, AirflowRescheduleException, AirflowSensorTimeout, @@ -94,7 +93,7 @@ XComForMappingNotPushed, ) from airflow.listeners.listener import get_listener_manager -from airflow.models.asset import AssetActive, AssetEvent, AssetModel +from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies from airflow.models.dagbag import DagBag from airflow.models.log import Log @@ -2744,104 +2743,105 @@ def _run_raw_task( def register_asset_changes_in_db( ti: TaskInstance, task_outlets: list[AssetProfile], - outlet_events: list[Any], + outlet_events: list[dict[str, Any]], session: Session = NEW_SESSION, ) -> None: - # One task only triggers one asset event for each asset with the same extra. - # This tuple[asset uri, extra] to sets alias names mapping is used to find whether - # there're assets with same uri but different extra that we need to emit more than one asset events. 
- asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set) - asset_name_refs: set[str] = set() - asset_uri_refs: set[str] = set() - - for obj in task_outlets: - ti.log.debug("outlet obj %s", obj) - # Lineage can have other types of objects besides assets - if obj.type == Asset.__name__: + asset_keys = {(o.name, o.uri) for o in task_outlets if o.type == Asset.__name__} + asset_name_refs = {o.name for o in task_outlets if o.type == AssetNameRef.__name__} + asset_uri_refs = {o.uri for o in task_outlets if o.type == AssetUriRef.__name__} + + @cache + def asset_event_extras_by_unique_key() -> dict[AssetUniqueKey, dict]: + return { + AssetUniqueKey(**event["dest_asset_key"]): event["extra"] + for event in outlet_events + if "source_alias_name" not in event + } + + @cache + def asset_event_extras_by_name() -> dict[str, dict]: + return {key.name: extra for key, extra in asset_event_extras_by_unique_key().items()} + + @cache + def asset_event_extras_by_uri() -> dict[str, dict]: + return {key.uri: extra for key, extra in asset_event_extras_by_unique_key().items()} + + asset_models: Iterable[AssetModel] + if asset_keys: + asset_models = session.scalars( + select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(asset_keys)) + ) + for am in asset_models: + ti.log.debug("register event for asset %s", am) asset_manager.register_asset_change( task_instance=ti, - asset=Asset(name=obj.name, uri=obj.uri), # type: ignore - extra=outlet_events[0]["extra"], + asset=am, + extra=asset_event_extras_by_unique_key().get(AssetUniqueKey.from_asset(am)), session=session, ) - elif obj.type == AssetNameRef.__name__: - asset_name_refs.add(obj.name) # type: ignore - elif obj.type == AssetUriRef.__name__: - asset_uri_refs.add(obj.uri) # type: ignore - elif obj.type == AssetAlias.__name__: - outlet_events = list( - map( - lambda event: {**event, "dest_asset_key": AssetUniqueKey(**event["dest_asset_key"])}, - outlet_events, - ) - ) - for asset_alias_event in 
outlet_events: - asset_alias_name = asset_alias_event["source_alias_name"] - asset_unique_key = asset_alias_event["dest_asset_key"] - frozen_extra = frozenset(asset_alias_event["extra"].items()) - asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name) - - asset_unique_keys = {key for key, _ in asset_alias_names} - existing_aliased_assets: set[AssetUniqueKey] = { - AssetUniqueKey.from_asset(asset_obj) - for asset_obj in session.scalars( - select(AssetModel).where( - tuple_(AssetModel.name, AssetModel.uri).in_( - attrs.astuple(key) for key in asset_unique_keys - ) - ) - ) - } - inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys( - asset_unique_keys={key for key in asset_unique_keys if key in existing_aliased_assets}, - session=session, - ) - if inactive_asset_unique_keys: - raise AirflowInactiveAssetAddedToAssetAliasException(inactive_asset_unique_keys) - - if missing_assets := [ - asset_unique_key.to_asset() - for asset_unique_key, _ in asset_alias_names - if asset_unique_key not in existing_aliased_assets - ]: - asset_manager.create_assets(missing_assets, session=session) - ti.log.warning("Created new assets for alias reference: %s", missing_assets) - session.flush() # Needed because we need the id for fk. 
- - for (unique_key, extra_items), alias_names in asset_alias_names.items(): - ti.log.info( - 'Creating event for %r through aliases "%s"', - unique_key, - ", ".join(alias_names), + if asset_name_refs: + asset_models = session.scalars( + select(AssetModel).where(AssetModel.name.in_(asset_name_refs), AssetModel.active.has()) ) - asset_manager.register_asset_change( - task_instance=ti, - asset=unique_key, - aliases=[AssetAlias(name=name) for name in alias_names], - extra=dict(extra_items), - session=session, - source_alias_names=alias_names, + for am in asset_models: + ti.log.debug("register event for asset name ref %s", am) + asset_manager.register_asset_change( + task_instance=ti, + asset=am, + extra=asset_event_extras_by_name().get(am.name), + session=session, + ) + if asset_uri_refs: + asset_models = session.scalars( + select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs), AssetModel.active.has()) ) + for am in asset_models: + ti.log.debug("register event for asset uri ref %s", am) + asset_manager.register_asset_change( + task_instance=ti, + asset=am, + extra=asset_event_extras_by_uri().get(am.uri), + session=session, + ) - # Handle events derived from references. 
- asset_stmt = select(AssetModel).where(AssetModel.name.in_(asset_name_refs), AssetModel.active.has()) - for asset_model in session.scalars(asset_stmt): - ti.log.info("Creating event through asset name reference %r", asset_model.name) - asset_manager.register_asset_change( - task_instance=ti, - asset=asset_model, - extra=outlet_events[asset_model].extra, - session=session, - ) - asset_stmt = select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs), AssetModel.active.has()) - for asset_model in session.scalars(asset_stmt): - ti.log.info("Creating event for through asset URI reference %r", asset_model.uri) - asset_manager.register_asset_change( - task_instance=ti, - asset=asset_model, - extra=outlet_events[asset_model].extra, - session=session, + @cache + def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]: + d = defaultdict(set) + for event in outlet_events: + try: + alias_name = event["source_alias_name"] + except KeyError: + continue + asset_key = AssetUniqueKey(**event["dest_asset_key"]) + extra_key = frozenset(event["extra"].items()) + d[asset_key, extra_key].add(alias_name) + return d + + outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ if o.name} + if outlet_alias_names and (event_from_aliases := asset_event_extras_from_aliases()): + asset_alias_models: dict[str, AssetAliasModel] = dict( + session.execute( + select(AssetAliasModel.name, AssetAliasModel).where( + AssetAliasModel.name.in_(outlet_alias_names) + ) + ) ) + for (asset_key, extra_key), event_aliase_names in event_from_aliases.items(): + aliases = [ + alias_model + for alias_model in (asset_alias_models.get(n) for n in event_aliase_names) + if alias_model is not None + ] + if not aliases: + continue + ti.log.debug("register event for asset %s with alias %s", asset_key, aliases) + asset_manager.register_asset_change( + task_instance=ti, + asset=asset_key, + aliases=aliases, + extra=dict(extra_key), + session=session, + ) def 
_execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): """Prepare Task for Execution.""" diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b949644274482..74d7fade14872 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -209,7 +209,7 @@ class TISuccessStatePayload(BaseModel): state: Annotated[Literal["success"] | None, Field(title="State")] = "success" end_date: Annotated[datetime, Field(title="End Date")] task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None - outlet_events: Annotated[list | None, Field(title="Outlet Events")] = None + outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None class TITargetStatePayload(BaseModel): diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index ea7a2c9e8ad77..32204bd4139eb 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -637,7 +637,7 @@ def as_expression(self) -> Any: @attrs.define -class AssetAliasEvent: +class AssetAliasEvent(attrs.AttrsInstance): """Representation of asset event to be triggered by an asset alias.""" source_alias_name: str diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 76b86e28d1fe2..9ac0b2fb27c71 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -226,7 +226,7 @@ def __iter__(self) -> Iterator[Asset | AssetAlias]: def __len__(self) -> int: return len(self._dict) - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: + def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor: hashable_key: 
BaseAssetUniqueKey if isinstance(key, Asset): hashable_key = AssetUniqueKey.from_asset(key) @@ -284,6 +284,8 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset @attrs.define(init=False) class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], Any]): + """Lazy mapping of inlet asset event accessors.""" + _inlets: list[Any] _assets: dict[AssetUniqueKey, Asset] _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias] diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 59c15ea3723e4..4b42f8d9f8c2d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -23,7 +23,7 @@ import functools import os import sys -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Iterator, Mapping from datetime import datetime, timezone from io import FileIO from pathlib import Path @@ -46,7 +46,7 @@ ) from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.baseoperator import BaseOperator, ExecutorSafeguard from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params @@ -515,33 +515,27 @@ def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: return {field: serialize_template_field(getattr(task, field), field) for field in task.template_fields} -def _process_outlets(context: Context, outlets: list[AssetProfile]): - added_alias_to_task_outlet = False - task_outlets: list[AssetProfile] = [] - outlet_events: list[Any] = [] - events = 
context["outlet_events"] - - for obj in outlets or []: - # Lineage can have other types of objects besides assets +def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: + # Lineage can have other types of objects besides assets, so we need to process them a bit. + for obj in lineage_objects or (): if isinstance(obj, Asset): - task_outlets.append(AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)) - outlet_events.append(attrs.asdict(events[obj])) # type: ignore + yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__) elif isinstance(obj, AssetNameRef): - task_outlets.append(AssetProfile(name=obj.name, type=AssetNameRef.__name__)) - # Send all events, filtering can be done in API server. - outlet_events.append(attrs.asdict(events)) # type: ignore + yield AssetProfile(name=obj.name, type=AssetNameRef.__name__) elif isinstance(obj, AssetUriRef): - task_outlets.append(AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)) - # Send all events, filtering can be done in API server. - outlet_events.append(attrs.asdict(events)) # type: ignore + yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__) elif isinstance(obj, AssetAlias): - if not added_alias_to_task_outlet: - task_outlets.append(AssetProfile(name=obj.name, type=AssetAlias.__name__)) - added_alias_to_task_outlet = True - for asset_alias_event in events[obj].asset_alias_events: - outlet_events.append(attrs.asdict(asset_alias_event)) + yield AssetProfile(name=obj.name, type=AssetAlias.__name__) + - return task_outlets, outlet_events +def _serialize_outlet_events(events: OutletEventAccessors) -> Iterator[dict[str, Any]]: + # We just collect everything the user recorded in the accessors. + # Further filtering will be done in the API server. 
+ for key, accessor in events._dict.items(): + if isinstance(key, AssetUniqueKey): + yield {"dest_asset_key": key, "extra": accessor.extra} + for alias_event in accessor.asset_alias_events: + yield attrs.asdict(alias_event) def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: @@ -610,11 +604,10 @@ def run( _push_xcom_if_needed(result, ti, log) - task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) msg = SucceedTask( end_date=datetime.now(tz=timezone.utc), - task_outlets=task_outlets, - outlet_events=outlet_events, + task_outlets=list(_build_asset_profiles(ti.task.outlets)), + outlet_events=list(_serialize_outlet_events(context["outlet_events"])), # type: ignore ) state = TerminalTIState.SUCCESS except TaskDeferred as defer: diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 8af24a328978a..9f325e7c05c67 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -20,13 +20,15 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Protocol, Union +import attrs + from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet if TYPE_CHECKING: from collections.abc import Iterator from datetime import datetime - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator @@ -80,7 +82,7 @@ def xcom_push(self, key: str, value: Any) -> None: ... def get_template_context(self) -> Context: ... 
-class OutletEventAccessorProtocol(Protocol): +class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance): """Protocol for managing access to a specific outlet event accessor.""" key: BaseAssetUniqueKey @@ -102,4 +104,4 @@ class OutletEventAccessorsProtocol(Protocol): def __iter__(self) -> Iterator[Asset | AssetAlias]: ... def __len__(self) -> int: ... - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ... + def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessorProtocol: ... From 222bb17115d73bf129436f8b4812aa5260bb6678 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 00:07:08 +0800 Subject: [PATCH 02/10] Reuse code in execution_time --- airflow/models/taskinstance.py | 35 +++++++------------ .../airflow/sdk/execution_time/task_runner.py | 7 ++-- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7ffacc12bdc9f..a1ec69246787c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -103,7 +103,6 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.api.datamodels._generated import AssetProfile from airflow.sdk.definitions._internal.templater import SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.param import process_params @@ -158,6 +157,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun + from airflow.sdk.api.datamodels._generated import AssetProfile from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import RuntimeTaskInstanceProtocol @@ 
-355,28 +355,17 @@ def _run_raw_task( if not test_mode: _add_log(event=ti.state, task_instance=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: - added_alias_to_task_outlet = False - task_outlets = [] - outlet_events = [] - events = context["outlet_events"] - for obj in ti.task.outlets or []: - # Lineage can have other types of objects besides assets - if isinstance(obj, Asset): - task_outlets.append(AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)) - outlet_events.append(attrs.asdict(events[obj])) # type: ignore - elif isinstance(obj, AssetNameRef): - task_outlets.append(AssetProfile(name=obj.name, type=AssetNameRef.__name__)) - outlet_events.append(attrs.asdict(events)) # type: ignore - elif isinstance(obj, AssetUriRef): - task_outlets.append(AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)) - outlet_events.append(attrs.asdict(events)) # type: ignore - elif isinstance(obj, AssetAlias): - if not added_alias_to_task_outlet: - task_outlets.append(AssetProfile(name=obj.name, type=AssetAlias.__name__)) - added_alias_to_task_outlet = True - for asset_alias_event in events[obj].asset_alias_events: - outlet_events.append(attrs.asdict(asset_alias_event)) - TaskInstance.register_asset_changes_in_db(ti, task_outlets, outlet_events, session=session) + from airflow.sdk.execution_time.task_runner import ( + _build_asset_profiles, + _serialize_outlet_events, + ) + + TaskInstance.register_asset_changes_in_db( + ti, + list(_build_asset_profiles(ti.task.outlets)), + list(_serialize_outlet_events(context["outlet_events"])), + session=session, + ) TaskInstance.save_to_db(ti=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 4b42f8d9f8c2d..2eea4dffdbdae 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -83,6 +83,7 @@ from 
airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.context import Context + from airflow.sdk.types import OutletEventAccessorsProtocol class TaskRunnerMarker: @@ -528,7 +529,9 @@ def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: yield AssetProfile(name=obj.name, type=AssetAlias.__name__) -def _serialize_outlet_events(events: OutletEventAccessors) -> Iterator[dict[str, Any]]: +def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, Any]]: + if TYPE_CHECKING: + assert isinstance(events, OutletEventAccessors) # We just collect everything the user recorded in the accessors. # Further filtering will be done in the API server. for key, accessor in events._dict.items(): @@ -607,7 +610,7 @@ def run( msg = SucceedTask( end_date=datetime.now(tz=timezone.utc), task_outlets=list(_build_asset_profiles(ti.task.outlets)), - outlet_events=list(_serialize_outlet_events(context["outlet_events"])), # type: ignore + outlet_events=list(_serialize_outlet_events(context["outlet_events"])), ) state = TerminalTIState.SUCCESS except TaskDeferred as defer: From c2f7f16eb8874d1eb2850ed4939c16e85b5af72d Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 02:02:49 +0800 Subject: [PATCH 03/10] Fix tests on both sides of the API The task runner now emits slightly different data; extra is no longer eagerly resolved; the server side fills in the default automatically. Fixtures using asset refs now work correctly. The API server side tests are vastly improved to cover more kinds of data, including checking that events against other aliases are not incorrectly picked up when processing events originating from an alias.
--- airflow/models/taskinstance.py | 9 +- .../execution_time/test_task_runner.py | 55 +------ .../routes/test_task_instances.py | 134 +++++++++++++----- 3 files changed, 107 insertions(+), 91 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a1ec69246787c..7c61f5e1fc2b8 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2806,15 +2806,16 @@ def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], d[asset_key, extra_key].add(alias_name) return d - outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ if o.name} + outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name} if outlet_alias_names and (event_from_aliases := asset_event_extras_from_aliases()): - asset_alias_models: dict[str, AssetAliasModel] = dict( - session.execute( + asset_alias_models: dict[str, AssetAliasModel] = { + name: model + for name, model in session.execute( select(AssetAliasModel.name, AssetAliasModel).where( AssetAliasModel.name.in_(outlet_alias_names) ) ) - ) + } for (asset_key, extra_key), event_aliase_names in event_from_aliases.items(): aliases = [ alias_model diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 50c89e047d22b..3400312d6b8bf 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -693,13 +693,7 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch task_outlets=[ AssetProfile(name="s3://bucket/my-task", uri="s3://bucket/my-task", type="Asset") ], - outlet_events=[ - { - "key": {"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], + outlet_events=[], ), id="asset", ), @@ -711,13 +705,7 @@ def test_dag_parsing_context(make_ti_context, 
mock_supervisor_comms, monkeypatch task_outlets=[ AssetProfile(name="s3://bucket/my-task", uri="s3://bucket/my-task", type="Asset") ], - outlet_events=[ - { - "key": {"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], + outlet_events=[], ), id="dataset", ), @@ -729,13 +717,7 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch task_outlets=[ AssetProfile(name="s3://bucket/my-task", uri="s3://bucket/my-task", type="Asset") ], - outlet_events=[ - { - "key": {"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], + outlet_events=[], ), id="model", ), @@ -745,15 +727,8 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch state="success", end_date=timezone.datetime(2024, 12, 3, 10, 0), task_outlets=[AssetProfile(name="s3://bucket/my-task", type="AssetNameRef")], - outlet_events=[ - { - "key": {"name": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], + outlet_events=[], ), - marks=[pytest.mark.xfail], # Currently not handled correctly in task runner. id="name-ref", ), pytest.param( @@ -762,15 +737,8 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch state="success", end_date=timezone.datetime(2024, 12, 3, 10, 0), task_outlets=[AssetProfile(uri="s3://bucket/my-task", type="AssetUriRef")], - outlet_events=[ - { - "key": {"uri": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], + outlet_events=[], ), - marks=[pytest.mark.xfail], # Currently not handled correctly in task runner. 
id="uri-ref", ), pytest.param( @@ -864,18 +832,7 @@ def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): AssetProfile(name="name", uri="s3://bucket/my-task", type="Asset"), AssetProfile(name="new-name", uri="s3://bucket/my-task", type="Asset"), ], - outlet_events=[ - { - "asset_alias_events": [], - "extra": {}, - "key": {"name": "name", "uri": "s3://bucket/my-task"}, - }, - { - "asset_alias_events": [], - "extra": {}, - "key": {"name": "new-name", "uri": "s3://bucket/my-task"}, - }, - ], + outlet_events=[], ), id="runtime_checks_pass", ), diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 19701aa6a4b03..d583a26e88c3b 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -318,9 +318,11 @@ def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instan class TestTIUpdateState: def setup_method(self): + clear_db_assets() clear_db_runs() def teardown_method(self): + clear_db_assets() clear_db_runs() @pytest.mark.parametrize( @@ -359,51 +361,45 @@ def test_ti_update_state_to_terminal( assert ti.end_date == end_date @pytest.mark.parametrize( - ("task_outlets", "outlet_events"), + "task_outlets", [ - ( - [{"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task", "type": "Asset"}], + pytest.param([{"name": "my-task", "uri": "s3://bucket/my-task", "type": "Asset"}], id="asset"), + pytest.param([{"name": "my-task", "type": "AssetNameRef"}], id="name-ref"), + pytest.param([{"uri": "s3://bucket/my-task", "type": "AssetUriRef"}], id="uri-ref"), + ], + ) + @pytest.mark.parametrize( + "outlet_events, expected_extra", + [ + pytest.param([], {}, id="default"), + pytest.param( [ { - "key": {"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task"}, - "extra": {}, - "asset_alias_events": [], - } - ], - ), - ( - [{"type": "AssetAlias"}], - [ + 
"dest_asset_key": {"name": "my-task", "uri": "s3://bucket/my-task"}, + "extra": {"foo": 1}, + }, { - "source_alias_name": "example-alias", - "dest_asset_key": {"name": "s3://bucket/my-task", "uri": "s3://bucket/my-task"}, - "extra": {}, - } + "dest_asset_key": {"name": "my-task-2", "uri": "s3://bucket/my-task-2"}, + "extra": {"foo": 2}, + }, ], + {"foo": 1}, + id="extra", ), ], ) def test_ti_update_state_to_success_with_asset_events( - self, client, session, create_task_instance, task_outlets, outlet_events + self, client, session, create_task_instance, task_outlets, outlet_events, expected_extra ): - clear_db_assets() - clear_db_runs() - asset = AssetModel( id=1, - name="s3://bucket/my-task", + name="my-task", uri="s3://bucket/my-task", group="asset", extra={}, ) asset_active = AssetActive.for_asset(asset) session.add_all([asset, asset_active]) - asset_type = task_outlets[0]["type"] - if asset_type == "AssetAlias": - _create_asset_aliases(session, num=1) - asset_alias = session.query(AssetAliasModel).all() - assert len(asset_alias) == 1 - assert asset_alias == [AssetAliasModel(name="simple1")] ti = create_task_instance( task_id="test_ti_update_state_to_success_with_asset_events", @@ -426,18 +422,80 @@ def test_ti_update_state_to_success_with_asset_events( assert response.text == "" session.expire_all() - # check if asset was created properly - asset = session.query(AssetModel).all() - assert len(asset) == 1 - assert asset == [AssetModel(name="s3://bucket/my-task", uri="s3://bucket/my-task", extra={})] - - event = session.query(AssetEvent).all() + event = session.scalars(select(AssetEvent)).all() assert len(event) == 1 - assert event[0].asset_id == 1 - assert event[0].asset == AssetModel(name="s3://bucket/my-task", uri="s3://bucket/my-task", extra={}) - assert event[0].extra == {} - if asset_type == "AssetAlias": - assert event[0].source_aliases == [AssetAliasModel(name="example-alias")] + assert event[0].asset == AssetModel(name="my-task", 
uri="s3://bucket/my-task", extra={}) + assert event[0].extra == expected_extra + + @pytest.mark.parametrize( + "outlet_events, expected_extra", + [ + pytest.param([], None, id="default"), + pytest.param( + [ + { + "dest_asset_key": {"name": "my-task", "uri": "s3://bucket/my-task"}, + "source_alias_name": "simple1", + "extra": {"foo": 1}, + }, + { + "dest_asset_key": {"name": "my-task-2", "uri": "s3://bucket/my-task-2"}, + "extra": {"foo": 2}, + }, + { + "dest_asset_key": {"name": "my-task-2", "uri": "s3://bucket/my-task-2"}, + "source_alias_name": "simple2", + "extra": {"foo": 3}, + }, + ], + {"foo": 1}, + id="extra", + ), + ], + ) + def test_ti_update_state_to_success_with_asset_alias_events( + self, client, session, create_task_instance, outlet_events, expected_extra + ): + asset = AssetModel( + id=1, + name="my-task", + uri="s3://bucket/my-task", + group="asset", + extra={}, + ) + asset_active = AssetActive.for_asset(asset) + session.add_all([asset, asset_active]) + + _create_asset_aliases(session, num=2) + + ti = create_task_instance( + task_id="test_ti_update_state_to_success_with_asset_events", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": "success", + "end_date": DEFAULT_END_DATE.isoformat(), + "task_outlets": [{"name": "simple1", "type": "AssetAlias"}], + "outlet_events": outlet_events, + }, + ) + + assert response.status_code == 204 + assert response.text == "" + session.expire_all() + + events = session.scalars(select(AssetEvent)).all() + if expected_extra is None: + assert events == [] + else: + assert len(events) == 1 + assert events[0].asset == AssetModel(name="my-task", uri="s3://bucket/my-task", extra={}) + assert events[0].extra == expected_extra def test_ti_update_state_not_found(self, client, session): """ From c0742d657af8075f4b14757a26c7f0dcfe2b0bb8 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 
13:06:34 +0800 Subject: [PATCH 04/10] Remove redundant arg from register_asset_change The 'aliases' and 'source_alias_names' basically serve the same purpose, but we use each for a different section of the code. They can become one. --- airflow/assets/manager.py | 5 ++--- airflow/models/taskinstance.py | 2 +- tests/assets/test_manager.py | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py index cecd76b07c3f4..1b330a8d7cb26 100644 --- a/airflow/assets/manager.py +++ b/airflow/assets/manager.py @@ -110,8 +110,7 @@ def register_asset_change( task_instance: TaskInstance | None = None, asset: Asset | AssetModel | AssetUniqueKey, extra=None, - aliases: Collection[AssetAlias] = (), - source_alias_names: Iterable[str] | None = None, + source_alias_names: Collection[str] = (), session: Session, **kwargs, ) -> AssetEvent | None: @@ -136,7 +135,7 @@ def register_asset_change( return None cls._add_asset_alias_association( - alias_names={alias.name for alias in aliases}, asset_model=asset_model, session=session + alias_names=source_alias_names, asset_model=asset_model, session=session ) event_kwargs = { diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7c61f5e1fc2b8..e8e9c2878a985 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2828,7 +2828,7 @@ def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], asset_manager.register_asset_change( task_instance=ti, asset=asset_key, - aliases=aliases, + source_alias_names=event_aliase_names, extra=dict(extra_key), session=session, ) diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py index b47dae7f2e7f1..3ffd14643eadd 100644 --- a/tests/assets/test_manager.py +++ b/tests/assets/test_manager.py @@ -34,7 +34,7 @@ DagScheduleAssetReference, ) from airflow.models.dag import DagModel -from airflow.sdk.definitions.asset import Asset, AssetAlias +from 
airflow.sdk.definitions.asset import Asset from tests.listeners import asset_listener @@ -124,12 +124,10 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in session.flush() asset = Asset(uri="test://asset1", name="test_asset_uri") - asset_alias = AssetAlias(name="test_alias_name", group="test") asset_manager = AssetManager() asset_manager.register_asset_change( task_instance=mock_task_instance, asset=asset, - aliases=[asset_alias], source_alias_names=["test_alias_name"], session=session, ) From 936b2e32af12a8d135dfac01c6c1de119590a512 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 13:10:35 +0800 Subject: [PATCH 05/10] Check for inactive assets Exceptions are now raised if an event is raised against inactive assets. --- airflow/exceptions.py | 41 ++++-- airflow/models/taskinstance.py | 137 ++++++++++-------- .../airflow/sdk/definitions/asset/__init__.py | 2 +- .../airflow/sdk/execution_time/task_runner.py | 2 +- tests/models/test_taskinstance.py | 45 ++---- 5 files changed, 121 insertions(+), 106 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index ca0eb6509eb6c..f9003a8eddc09 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -27,13 +27,13 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any, NamedTuple +from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: from collections.abc import Sized from airflow.models import DagRun - from airflow.sdk.definitions.asset import AssetUniqueKey class AirflowException(Exception): @@ -113,33 +113,42 @@ class AirflowFailException(AirflowException): """Raise when the task should be failed without retrying.""" -class AirflowExecuteWithInactiveAssetExecption(AirflowFailException): +class _AirflowExecuteWithInactiveAssetExecption(AirflowFailException): """Raise when the task is executed with inactive assets.""" - def 
__init__(self, inactive_asset_unikeys: Collection[AssetUniqueKey]) -> None: - self.inactive_asset_unique_keys = inactive_asset_unikeys + main_message: str + + def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef | AssetUriRef]) -> None: + self.inactive_asset_keys = inactive_asset_keys + + @staticmethod + def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str: + if isinstance(key, AssetUniqueKey): + return f"Asset(name={key.name!r}, uri={key.uri!r})" + elif isinstance(key, AssetNameRef): + return f"Asset.ref(name={key.name!r})" + elif isinstance(key, AssetUriRef): + return f"Asset.ref(uri={key.uri!r})" + return repr(key) # Should not happen, but let's fail more gracefully in an exception. + + def __str__(self) -> str: + return f"{self.main_message}: {self.inactive_assets_message}" + + @property - def inactive_assets_error_msg(self): - return ", ".join( - f'Asset(name="{key.name}", uri="{key.uri}")' for key in self.inactive_asset_unique_keys - ) + def inactive_assets_message(self): + return ", ".join(self._render_asset_key(key) for key in self.inactive_asset_keys) -class AirflowInactiveAssetInInletOrOutletException(AirflowExecuteWithInactiveAssetExecption): +class AirflowInactiveAssetInInletOrOutletException(_AirflowExecuteWithInactiveAssetExecption): """Raise when the task is executed with inactive assets in its inlet or outlet.""" - def __str__(self) -> str: - return f"Task has the following inactive assets in its inlets or outlets: {self.inactive_assets_error_msg}" + main_message = "Task has the following inactive assets in its inlets or outlets" -class AirflowInactiveAssetAddedToAssetAliasException(AirflowExecuteWithInactiveAssetExecption): +class AirflowInactiveAssetAddedToAssetAliasException(_AirflowExecuteWithInactiveAssetExecption): """Raise when inactive assets are added to an asset alias.""" - def __str__(self) -> str: - return ( - f"The following assets accessed by an AssetAlias are inactive:
{self.inactive_assets_error_msg}" - ) + main_message = "The following assets accessed by an AssetAlias are inactive" class AirflowOptionalProviderFeatureException(AirflowException): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e8e9c2878a985..1742a816c29a9 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -80,6 +80,7 @@ from airflow.exceptions import ( AirflowException, AirflowFailException, + AirflowInactiveAssetAddedToAssetAliasException, AirflowInactiveAssetInInletOrOutletException, AirflowRescheduleException, AirflowSensorTimeout, @@ -93,7 +94,7 @@ XComForMappingNotPushed, ) from airflow.listeners.listener import get_listener_manager -from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel +from airflow.models.asset import AssetActive, AssetEvent, AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies from airflow.models.dagbag import DagBag from airflow.models.log import Log @@ -2735,66 +2736,91 @@ def register_asset_changes_in_db( outlet_events: list[dict[str, Any]], session: Session = NEW_SESSION, ) -> None: - asset_keys = {(o.name, o.uri) for o in task_outlets if o.type == Asset.__name__} - asset_name_refs = {o.name for o in task_outlets if o.type == AssetNameRef.__name__} - asset_uri_refs = {o.uri for o in task_outlets if o.type == AssetUriRef.__name__} - - @cache - def asset_event_extras_by_unique_key() -> dict[AssetUniqueKey, dict]: - return { - AssetUniqueKey(**event["dest_asset_key"]): event["extra"] - for event in outlet_events - if "source_alias_name" not in event - } + asset_keys = { + AssetUniqueKey(o.name, o.uri) + for o in task_outlets + if o.type == Asset.__name__ and o.name and o.uri + } + asset_name_refs = { + Asset.ref(name=o.name) for o in task_outlets if o.type == AssetNameRef.__name__ and o.name + } + asset_uri_refs = { + Asset.ref(uri=o.uri) for o in task_outlets if o.type == AssetUriRef.__name__ and o.uri + } 
- @cache - def asset_event_extras_by_name() -> dict[str, dict]: - return {key.name: extra for key, extra in asset_event_extras_by_unique_key().items()} + asset_models: dict[AssetUniqueKey, AssetModel] = { + AssetUniqueKey.from_asset(am): am + for am in session.scalars( + select(AssetModel).where( + AssetModel.active.has(), + or_( + tuple_(AssetModel.name, AssetModel.uri).in_(attrs.astuple(k) for k in asset_keys), + AssetModel.name.in_(r.name for r in asset_name_refs), + AssetModel.uri.in_(r.uri for r in asset_uri_refs), + ), + ) + ) + } - @cache - def asset_event_extras_by_uri() -> dict[str, dict]: - return {key.uri: extra for key, extra in asset_event_extras_by_unique_key().items()} + asset_event_extras: dict[AssetUniqueKey, dict] = { + AssetUniqueKey(**event["dest_asset_key"]): event["extra"] + for event in outlet_events + if "source_alias_name" not in event + } + + bad_asset_keys: set[AssetUniqueKey | AssetNameRef | AssetUriRef] = set() - asset_models: Iterable[AssetModel] - if asset_keys: - asset_models = session.scalars( - select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(asset_keys)) + for key in asset_keys: + try: + am = asset_models[key] + except KeyError: + bad_asset_keys.add(key) + continue + ti.log.debug("register event for asset %s", am) + asset_manager.register_asset_change( + task_instance=ti, + asset=am, + extra=asset_event_extras.get(key), + session=session, ) - for am in asset_models: - ti.log.debug("register event for asset %s", am) - asset_manager.register_asset_change( - task_instance=ti, - asset=am, - extra=asset_event_extras_by_unique_key().get(AssetUniqueKey.from_asset(am)), - session=session, - ) + if asset_name_refs: - asset_models = session.scalars( - select(AssetModel).where(AssetModel.name.in_(asset_name_refs), AssetModel.active.has()) - ) - for am in asset_models: + asset_models_by_name = {key.name: am for key, am in asset_models.items()} + asset_event_extras_by_name = {key.name: extra for key, extra in 
asset_event_extras.items()} + for nref in asset_name_refs: + try: + am = asset_models_by_name[nref.name] + except KeyError: + bad_asset_keys.add(nref) + continue ti.log.debug("register event for asset name ref %s", am) asset_manager.register_asset_change( task_instance=ti, asset=am, - extra=asset_event_extras_by_name().get(am.name), + extra=asset_event_extras_by_name.get(nref.name), session=session, ) if asset_uri_refs: - asset_models = session.scalars( - select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs), AssetModel.active.has()) - ) - for am in asset_models: + asset_models_by_uri = {key.uri: am for key, am in asset_models.items()} + asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()} + for uref in asset_uri_refs: + try: + am = asset_models_by_uri[uref.uri] + except KeyError: + bad_asset_keys.add(uref) + continue ti.log.debug("register event for asset uri ref %s", am) asset_manager.register_asset_change( task_instance=ti, asset=am, - extra=asset_event_extras_by_uri().get(am.uri), + extra=asset_event_extras_by_uri.get(uref.uri), session=session, ) - @cache - def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]: + if bad_asset_keys: + raise AirflowInactiveAssetInInletOrOutletException(bad_asset_keys) + + def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]: d = defaultdict(set) for event in outlet_events: try: @@ -2807,24 +2833,15 @@ def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], return d outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name} - if outlet_alias_names and (event_from_aliases := asset_event_extras_from_aliases()): - asset_alias_models: dict[str, AssetAliasModel] = { - name: model - for name, model in session.execute( - select(AssetAliasModel.name, AssetAliasModel).where( - AssetAliasModel.name.in_(outlet_alias_names) - ) - ) - } - for (asset_key, extra_key), 
event_aliase_names in event_from_aliases.items(): - aliases = [ - alias_model - for alias_model in (asset_alias_models.get(n) for n in event_aliase_names) - if alias_model is not None - ] - if not aliases: + if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()): + bad_alias_asset_keys = TaskInstance._get_inactive_asset_unique_keys( + {key for key, _ in event_extras_from_aliases}, + session=session, + ) + for (asset_key, extra_key), event_aliase_names in event_extras_from_aliases.items(): + if asset_key in bad_alias_asset_keys: continue - ti.log.debug("register event for asset %s with alias %s", asset_key, aliases) + ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names) asset_manager.register_asset_change( task_instance=ti, asset=asset_key, @@ -2832,6 +2849,8 @@ def asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], extra=dict(extra_key), session=session, ) + if bad_alias_asset_keys: + raise AirflowInactiveAssetAddedToAssetAliasException(bad_alias_asset_keys) def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): """Prepare Task for Execution.""" diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 32204bd4139eb..41e8cfd637df4 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -60,7 +60,7 @@ @attrs.define(frozen=True) -class AssetUniqueKey: +class AssetUniqueKey(attrs.AttrsInstance): """ Columns to identify an unique asset. 
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 2eea4dffdbdae..e42d67580661f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -536,7 +536,7 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d # Further filtering will be done in the API server. for key, accessor in events._dict.items(): if isinstance(key, AssetUniqueKey): - yield {"dest_asset_key": key, "extra": accessor.extra} + yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra} for alias_event in accessor.asset_alias_events: yield attrs.asdict(alias_event) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index c1db818548db3..38a6cfbf86840 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2562,7 +2562,7 @@ def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias" - asset_uri = "did_not_exists" + asset_uri = "does_not_exist" with dag_maker(dag_id="producer_dag", schedule=None, serialized=True, session=session): @@ -2572,34 +2572,21 @@ def producer(*, outlet_events): producer() - dr: DagRun = dag_maker.create_dagrun() + (ti,) = dag_maker.create_dagrun().get_task_instances(session=session) - for ti in dr.get_task_instances(session=session): + with pytest.raises(AirflowInactiveAssetAddedToAssetAliasException) as ctx: ti.run(session=session) + assert str(ctx.value) == ( + "The following assets accessed by an AssetAlias are inactive: " + "Asset(name='does_not_exist', uri='does_not_exist')" + ) - producer_event = session.scalar(select(AssetEvent).where(AssetEvent.source_task_id == "producer")) - - assert producer_event.source_task_id == "producer" - assert producer_event.source_dag_id == 
"producer_dag" - assert producer_event.source_run_id == "test" - assert producer_event.source_map_index == -1 - assert producer_event.asset.uri == asset_uri - assert producer_event.extra == {"key": "value"} - assert len(producer_event.source_aliases) == 1 - assert producer_event.source_aliases[0].name == asset_alias_name - - asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_uri)) - assert len(asset_obj.aliases) == 1 - assert asset_obj.aliases[0].name == asset_alias_name - - asset_alias_obj = session.scalar(select(AssetAliasModel)) - assert len(asset_alias_obj.assets) == 1 - assert asset_alias_obj.assets[0].uri == asset_uri + assert session.scalar(select(AssetEvent)) is None def test_outlet_asset_alias_asset_inactive(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias - asset_name = "did_not_exists" + asset_name = "did_not_exist" asset = Asset(asset_name) asset2 = Asset(asset_name, uri="test://asset") asm = AssetModel.from_public(asset) @@ -2625,7 +2612,7 @@ def producer_with_inactive(*, outlet_events): with pytest.raises(AirflowInactiveAssetAddedToAssetAliasException) as exc: tis["producer_with_inactive"].run(session=session) - assert 'Asset(name="did_not_exists", uri="test://asset/")' in str(exc.value) + assert "Asset(name='did_not_exist', uri='test://asset/')" in str(exc.value) producer_event = session.scalar( select(AssetEvent).where(AssetEvent.source_task_id == "producer_without_inactive") @@ -4083,8 +4070,8 @@ def duplicate_asset_task_in_outlet(*, outlet_events): with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc: tis["duplicate_asset_task_in_outlet"].run(session=session) - assert 'Asset(name="asset_second", uri="asset_second")' in str(exc.value) - assert 'Asset(name="asset_first", uri="test://asset/")' in str(exc.value) + assert "Asset(name='asset_second', uri='asset_second')" in str(exc.value) + assert "Asset(name='asset_first', uri='test://asset/')" in str(exc.value) 
@pytest.mark.want_activate_assets(True) def test_run_with_inactive_assets_in_outlets_within_the_same_dag(self, dag_maker, session): @@ -4109,7 +4096,7 @@ def duplicate_asset_task(*, outlet_events): assert str(exc.value) == ( "Task has the following inactive assets in its inlets or outlets: " - 'Asset(name="asset_first", uri="test://asset/")' + "Asset(name='asset_first', uri='test://asset/')" ) @pytest.mark.want_activate_assets(True) @@ -4138,7 +4125,7 @@ def duplicate_asset_task(*, outlet_events): assert str(exc.value) == ( "Task has the following inactive assets in its inlets or outlets: " - 'Asset(name="asset_first", uri="test://asset/")' + "Asset(name='asset_first', uri='test://asset/')" ) @pytest.mark.want_activate_assets(True) @@ -4163,7 +4150,7 @@ def duplicate_asset_task(): assert str(exc.value) == ( "Task has the following inactive assets in its inlets or outlets: " - 'Asset(name="asset_first", uri="asset_first")' + "Asset(name='asset_first', uri='asset_first')" ) @pytest.mark.want_activate_assets(True) @@ -4192,7 +4179,7 @@ def duplicate_asset_task(*, outlet_events): assert str(exc.value) == ( "Task has the following inactive assets in its inlets or outlets: " - 'Asset(name="asset_first", uri="test://asset/")' + "Asset(name='asset_first', uri='test://asset/')" ) From 9d6883eb45f4af110306c279b0369b5b8e9a3b66 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 14:11:36 +0800 Subject: [PATCH 06/10] Correctly activate asset in test --- .../tests/unit/standard/decorators/test_python.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 314f96bf2154a..e637f66b5536b 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -26,6 +26,7 @@ from airflow.decorators import setup, task as 
task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound +from airflow.models.asset import AssetActive, AssetModel from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.utils import timezone @@ -977,13 +978,18 @@ def test_task_decorator_asset(dag_maker, session): uri = "s3://bucket/name" asset_name = "test_asset" + if AIRFLOW_V_3_0_PLUS: + asset = Asset(uri=uri, name=asset_name) + else: + asset = Asset(uri) + session.add(AssetModel.from_public(asset)) + session.add(AssetActive.for_asset(asset)) + with dag_maker(session=session) as dag: @dag.task() def up1() -> Asset: - if not AIRFLOW_V_3_0_PLUS: - return Asset(uri=uri) - return Asset(uri=uri, name=asset_name) + return asset @dag.task() def up2(src: Asset) -> str: From 7cdb0376659002c5226a44f684805bc995184fd8 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 14:16:26 +0800 Subject: [PATCH 07/10] Ignore events against undeclared aliases --- airflow/models/taskinstance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 1742a816c29a9..ac5e33953f44b 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2827,6 +2827,8 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], alias_name = event["source_alias_name"] except KeyError: continue + if alias_name not in outlet_alias_names: + continue asset_key = AssetUniqueKey(**event["dest_asset_key"]) extra_key = frozenset(event["extra"].items()) d[asset_key, extra_key].add(alias_name) From 37ce73d63ce64542170770f7c041226b2f192979 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 15:01:41 +0800 Subject: [PATCH 08/10] Fix 2.x compat in Standard provider test --- .../standard/tests/unit/standard/decorators/test_python.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 
deletions(-) diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index e637f66b5536b..727bfe7f685c9 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -26,7 +26,6 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.asset import AssetActive, AssetModel from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.utils import timezone @@ -970,9 +969,11 @@ def other(x): ... def test_task_decorator_asset(dag_maker, session): if AIRFLOW_V_3_0_PLUS: + from airflow.models.asset import AssetActive, AssetModel from airflow.sdk.definitions.asset import Asset else: from airflow.datasets import Dataset as Asset + from airflow.models.dataset import DatasetModel as AssetModel result = None uri = "s3://bucket/name" @@ -983,7 +984,8 @@ def test_task_decorator_asset(dag_maker, session): else: asset = Asset(uri) session.add(AssetModel.from_public(asset)) - session.add(AssetActive.for_asset(asset)) + if AIRFLOW_V_3_0_PLUS: + session.add(AssetActive.for_asset(asset)) with dag_maker(session=session) as dag: From 154b297893dbb0a45e5821bb8e2d55a960a71d00 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 16:03:06 +0800 Subject: [PATCH 09/10] Trigger alias events when invalid assets found This does not really matter in practice since inactive assets should have been caught before the task runs, but in case any are found, we can still try to trigger as many valid events as possible before failing the task. 
--- airflow/models/taskinstance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ac5e33953f44b..939d3af021b91 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2817,9 +2817,6 @@ def register_asset_changes_in_db( session=session, ) - if bad_asset_keys: - raise AirflowInactiveAssetInInletOrOutletException(bad_asset_keys) - def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]: d = defaultdict(set) for event in outlet_events: @@ -2854,6 +2851,9 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], if bad_alias_asset_keys: raise AirflowInactiveAssetAddedToAssetAliasException(bad_alias_asset_keys) + if bad_asset_keys: + raise AirflowInactiveAssetInInletOrOutletException(bad_asset_keys) + def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): """Prepare Task for Execution.""" if TYPE_CHECKING: From 7458bd9a7d91177614d092e73c22b94af0f05e1c Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 13 Mar 2025 16:05:23 +0800 Subject: [PATCH 10/10] Nits on exception classes --- airflow/exceptions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index f9003a8eddc09..0d1eea0051859 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -114,8 +114,6 @@ class AirflowFailException(AirflowException): class _AirflowExecuteWithInactiveAssetExecption(AirflowFailException): - """Raise when the task is executed with inactive assets.""" - main_message: str def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef | AssetUriRef]) -> None: @@ -135,7 +133,7 @@ def __str__(self) -> str: return f"{self.main_message}: {self.inactive_assets_message}" @property - def inactive_assets_message(self): + def inactive_assets_message(self) -> str: return ", 
".join(self._render_asset_key(key) for key in self.inactive_asset_keys)