Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TISuccessStatePayload(StrictBaseModel):
"""When the task completed executing"""

task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
outlet_events: Annotated[list[Any], Field(default_factory=list)]
outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)]


class TITargetStatePayload(StrictBaseModel):
Expand Down
5 changes: 2 additions & 3 deletions airflow/assets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def register_asset_change(
task_instance: TaskInstance | None = None,
asset: Asset | AssetModel | AssetUniqueKey,
extra=None,
aliases: Collection[AssetAlias] = (),
source_alias_names: Iterable[str] | None = None,
source_alias_names: Collection[str] = (),
session: Session,
**kwargs,
) -> AssetEvent | None:
Expand All @@ -136,7 +135,7 @@ def register_asset_change(
return None

cls._add_asset_alias_association(
alias_names={alias.name for alias in aliases}, asset_model=asset_model, session=session
alias_names=source_alias_names, asset_model=asset_model, session=session
)

event_kwargs = {
Expand Down
41 changes: 24 additions & 17 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple

from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from collections.abc import Sized

from airflow.models import DagRun
from airflow.sdk.definitions.asset import AssetUniqueKey


class AirflowException(Exception):
Expand Down Expand Up @@ -113,33 +113,40 @@ class AirflowFailException(AirflowException):
"""Raise when the task should be failed without retrying."""


class AirflowExecuteWithInactiveAssetExecption(AirflowFailException):
"""Raise when the task is executed with inactive assets."""
class _AirflowExecuteWithInactiveAssetExecption(AirflowFailException):
main_message: str

def __init__(self, inactive_asset_unikeys: Collection[AssetUniqueKey]) -> None:
self.inactive_asset_unique_keys = inactive_asset_unikeys
def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef | AssetUriRef]) -> None:
self.inactive_asset_keys = inactive_asset_keys

@staticmethod
def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str:
    """
    Render a single asset key or asset reference for use in an error message.

    :param key: A fully-qualified asset key, or a name-only / URI-only asset reference.
    :return: A string resembling the expression a DAG author would write for that asset.
    """
    if isinstance(key, AssetUniqueKey):
        return f"Asset(name={key.name!r}, uri={key.uri!r})"
    if isinstance(key, AssetNameRef):
        return f"Asset.ref(name={key.name!r})"
    if isinstance(key, AssetUriRef):
        return f"Asset.ref(uri={key.uri!r})"
    # Not expected to be reached. Fall back to repr() so that building the
    # exception message fails gracefully instead of raising on its own.
    return repr(key)

def __str__(self) -> str:
return f"{self.main_message}: {self.inactive_assets_message}"

@property
def inactive_assets_error_msg(self):
return ", ".join(
f'Asset(name="{key.name}", uri="{key.uri}")' for key in self.inactive_asset_unique_keys
)
def inactive_assets_message(self) -> str:
return ", ".join(self._render_asset_key(key) for key in self.inactive_asset_keys)


class AirflowInactiveAssetInInletOrOutletException(AirflowExecuteWithInactiveAssetExecption):
class AirflowInactiveAssetInInletOrOutletException(_AirflowExecuteWithInactiveAssetExecption):
"""Raise when the task is executed with inactive assets in its inlet or outlet."""

def __str__(self) -> str:
return f"Task has the following inactive assets in its inlets or outlets: {self.inactive_assets_error_msg}"
main_message = "Task has the following inactive assets in its inlets or outlets"


class AirflowInactiveAssetAddedToAssetAliasException(AirflowExecuteWithInactiveAssetExecption):
class AirflowInactiveAssetAddedToAssetAliasException(_AirflowExecuteWithInactiveAssetExecption):
"""Raise when inactive assets are added to an asset alias."""

def __str__(self) -> str:
return (
f"The following assets accessed by an AssetAlias are inactive: {self.inactive_assets_error_msg}"
)
main_message = "The following assets accessed by an AssetAlias are inactive"


class AirflowOptionalProviderFeatureException(AirflowException):
Expand Down
227 changes: 119 additions & 108 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.param import process_params
Expand Down Expand Up @@ -159,6 +158,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG as SchedulerDAG, DagModel
from airflow.models.dagrun import DagRun
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions._internal.abstractoperator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.types import RuntimeTaskInstanceProtocol
Expand Down Expand Up @@ -356,28 +356,17 @@ def _run_raw_task(
if not test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
added_alias_to_task_outlet = False
task_outlets = []
outlet_events = []
events = context["outlet_events"]
for obj in ti.task.outlets or []:
# Lineage can have other types of objects besides assets
if isinstance(obj, Asset):
task_outlets.append(AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__))
outlet_events.append(attrs.asdict(events[obj])) # type: ignore
elif isinstance(obj, AssetNameRef):
task_outlets.append(AssetProfile(name=obj.name, type=AssetNameRef.__name__))
outlet_events.append(attrs.asdict(events)) # type: ignore
elif isinstance(obj, AssetUriRef):
task_outlets.append(AssetProfile(uri=obj.uri, type=AssetUriRef.__name__))
outlet_events.append(attrs.asdict(events)) # type: ignore
elif isinstance(obj, AssetAlias):
if not added_alias_to_task_outlet:
task_outlets.append(AssetProfile(name=obj.name, type=AssetAlias.__name__))
added_alias_to_task_outlet = True
for asset_alias_event in events[obj].asset_alias_events:
outlet_events.append(attrs.asdict(asset_alias_event))
TaskInstance.register_asset_changes_in_db(ti, task_outlets, outlet_events, session=session)
from airflow.sdk.execution_time.task_runner import (
_build_asset_profiles,
_serialize_outlet_events,
)

TaskInstance.register_asset_changes_in_db(
ti,
list(_build_asset_profiles(ti.task.outlets)),
list(_serialize_outlet_events(context["outlet_events"])),
session=session,
)

TaskInstance.save_to_db(ti=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
Expand Down Expand Up @@ -2744,104 +2733,126 @@ def _run_raw_task(
def register_asset_changes_in_db(
ti: TaskInstance,
task_outlets: list[AssetProfile],
outlet_events: list[Any],
outlet_events: list[dict[str, Any]],
session: Session = NEW_SESSION,
) -> None:
# One task only triggers one asset event for each asset with the same extra.
# This tuple[asset uri, extra] to sets alias names mapping is used to find whether
# there're assets with same uri but different extra that we need to emit more than one asset events.
asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set)
asset_name_refs: set[str] = set()
asset_uri_refs: set[str] = set()

for obj in task_outlets:
ti.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides assets
if obj.type == Asset.__name__:
asset_manager.register_asset_change(
task_instance=ti,
asset=Asset(name=obj.name, uri=obj.uri), # type: ignore
extra=outlet_events[0]["extra"],
session=session,
)
elif obj.type == AssetNameRef.__name__:
asset_name_refs.add(obj.name) # type: ignore
elif obj.type == AssetUriRef.__name__:
asset_uri_refs.add(obj.uri) # type: ignore
elif obj.type == AssetAlias.__name__:
outlet_events = list(
map(
lambda event: {**event, "dest_asset_key": AssetUniqueKey(**event["dest_asset_key"])},
outlet_events,
)
)
for asset_alias_event in outlet_events:
asset_alias_name = asset_alias_event["source_alias_name"]
asset_unique_key = asset_alias_event["dest_asset_key"]
frozen_extra = frozenset(asset_alias_event["extra"].items())
asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)

asset_unique_keys = {key for key, _ in asset_alias_names}
existing_aliased_assets: set[AssetUniqueKey] = {
AssetUniqueKey.from_asset(asset_obj)
for asset_obj in session.scalars(
asset_keys = {
AssetUniqueKey(o.name, o.uri)
for o in task_outlets
if o.type == Asset.__name__ and o.name and o.uri
}
asset_name_refs = {
Asset.ref(name=o.name) for o in task_outlets if o.type == AssetNameRef.__name__ and o.name
}
asset_uri_refs = {
Asset.ref(uri=o.uri) for o in task_outlets if o.type == AssetUriRef.__name__ and o.uri
}

asset_models: dict[AssetUniqueKey, AssetModel] = {
AssetUniqueKey.from_asset(am): am
for am in session.scalars(
select(AssetModel).where(
tuple_(AssetModel.name, AssetModel.uri).in_(
attrs.astuple(key) for key in asset_unique_keys
)
AssetModel.active.has(),
or_(
tuple_(AssetModel.name, AssetModel.uri).in_(attrs.astuple(k) for k in asset_keys),
AssetModel.name.in_(r.name for r in asset_name_refs),
AssetModel.uri.in_(r.uri for r in asset_uri_refs),
),
)
)
}
inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
asset_unique_keys={key for key in asset_unique_keys if key in existing_aliased_assets},
session=session,
)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetAddedToAssetAliasException(inactive_asset_unique_keys)

if missing_assets := [
asset_unique_key.to_asset()
for asset_unique_key, _ in asset_alias_names
if asset_unique_key not in existing_aliased_assets
]:
asset_manager.create_assets(missing_assets, session=session)
ti.log.warning("Created new assets for alias reference: %s", missing_assets)
session.flush() # Needed because we need the id for fk.

for (unique_key, extra_items), alias_names in asset_alias_names.items():
ti.log.info(
'Creating event for %r through aliases "%s"',
unique_key,
", ".join(alias_names),
)
asset_manager.register_asset_change(
task_instance=ti,
asset=unique_key,
aliases=[AssetAlias(name=name) for name in alias_names],
extra=dict(extra_items),
session=session,
source_alias_names=alias_names,
)

# Handle events derived from references.
asset_stmt = select(AssetModel).where(AssetModel.name.in_(asset_name_refs), AssetModel.active.has())
for asset_model in session.scalars(asset_stmt):
ti.log.info("Creating event through asset name reference %r", asset_model.name)
asset_event_extras: dict[AssetUniqueKey, dict] = {
AssetUniqueKey(**event["dest_asset_key"]): event["extra"]
for event in outlet_events
if "source_alias_name" not in event
}

bad_asset_keys: set[AssetUniqueKey | AssetNameRef | AssetUriRef] = set()

for key in asset_keys:
try:
am = asset_models[key]
except KeyError:
bad_asset_keys.add(key)
continue
ti.log.debug("register event for asset %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_model,
extra=outlet_events[asset_model].extra,
asset=am,
extra=asset_event_extras.get(key),
session=session,
)
asset_stmt = select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs), AssetModel.active.has())
for asset_model in session.scalars(asset_stmt):
ti.log.info("Creating event for through asset URI reference %r", asset_model.uri)
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_model,
extra=outlet_events[asset_model].extra,

if asset_name_refs:
asset_models_by_name = {key.name: am for key, am in asset_models.items()}
asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()}
for nref in asset_name_refs:
try:
am = asset_models_by_name[nref.name]
except KeyError:
bad_asset_keys.add(nref)
continue
ti.log.debug("register event for asset name ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_name.get(nref.name),
session=session,
)
if asset_uri_refs:
asset_models_by_uri = {key.uri: am for key, am in asset_models.items()}
asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()}
for uref in asset_uri_refs:
try:
am = asset_models_by_uri[uref.uri]
except KeyError:
bad_asset_keys.add(uref)
continue
ti.log.debug("register event for asset uri ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_uri.get(uref.uri),
session=session,
)

def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]:
    """
    Group alias-sourced outlet events by destination asset and extra.

    Reads ``outlet_events`` and ``outlet_alias_names`` from the enclosing scope.
    Events without a ``source_alias_name`` entry, and events whose alias name is
    not in ``outlet_alias_names``, are ignored.

    :return: Mapping of ``(destination asset key, frozen extra items)`` to the
        set of alias names that produced an event with that key and extra.
    """
    grouped = defaultdict(set)
    for outlet_event in outlet_events:
        # Only events that carry an alias name are handled here.
        if "source_alias_name" not in outlet_event:
            continue
        source_alias = outlet_event["source_alias_name"]
        if source_alias not in outlet_alias_names:
            continue
        dest_key = AssetUniqueKey(**outlet_event["dest_asset_key"])
        # Freeze the extra's items so the pair can be used as a dict key.
        frozen_extra = frozenset(outlet_event["extra"].items())
        grouped[(dest_key, frozen_extra)].add(source_alias)
    return grouped

outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name}
if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()):
bad_alias_asset_keys = TaskInstance._get_inactive_asset_unique_keys(
{key for key, _ in event_extras_from_aliases},
session=session,
)
for (asset_key, extra_key), event_aliase_names in event_extras_from_aliases.items():
if asset_key in bad_alias_asset_keys:
continue
ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names)
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_key,
source_alias_names=event_aliase_names,
extra=dict(extra_key),
session=session,
)
if bad_alias_asset_keys:
raise AirflowInactiveAssetAddedToAssetAliasException(bad_alias_asset_keys)

if bad_asset_keys:
raise AirflowInactiveAssetInInletOrOutletException(bad_asset_keys)

def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
"""Prepare Task for Execution."""
Expand Down
Loading
Loading