diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition_at_runtime.py b/airflow-core/src/airflow/example_dags/example_asset_partition_at_runtime.py new file mode 100644 index 0000000000000..a0d46bf4c7d85 --- /dev/null +++ b/airflow-core/src/airflow/example_dags/example_asset_partition_at_runtime.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Dags demonstrating runtime partition key assignment. + +These Dags showcase :class:`~airflow.sdk.PartitionAtRuntime` and the new +``partition_keys`` API on outlet events, where partition keys are discovered +and emitted at task execution time rather than being derived from a schedule. + +Three patterns are shown: + +1. **Single runtime key** — a task resolves which partition to process and sets + exactly one key, which is also recorded on the Dag run. + +2. **Fan-out** — a task discovers a dynamic set of partitions at runtime and + emits one asset event per partition, triggering a separate downstream Dag run + for each. + +3. **Fan-out with per-partition metadata** — same as fan-out but each + :class:`~airflow.sdk.PartitionKey` also carries extra metadata that is + merged into the corresponding asset event. +""" + +from __future__ import annotations + +from airflow.sdk import ( + DAG, + AllowedKeyMapper, + Asset, + PartitionAtRuntime, + PartitionedAssetTimetable, + PartitionKey, + asset, + task, +) + +# --------------------------------------------------------------------------- +# Pattern 1: single runtime partition key +# --------------------------------------------------------------------------- + +single_key_asset = Asset( + uri="file://incoming/daily-report.csv", + name="daily_report", +) + + +with DAG( + dag_id="produce_daily_report", + schedule=PartitionAtRuntime(), + tags=["runtime-partition", "single-key"], +): + """ + Produce a daily report whose partition date is resolved at runtime. + + ``schedule=PartitionAtRuntime()`` signals that this Dag is externally + triggered and that a task will set the partition key during execution. + Setting exactly one key also records it on the Dag run, so downstream + :class:`~airflow.sdk.PartitionedAssetTimetable` Dags receive the correct + partition context. + """ + + @task(outlets=[single_key_asset]) + def generate_report(*, outlet_events): + """Resolve today's report partition and emit a single asset event.""" + # In practice, compute the partition from business logic or + # an upstream API response. + report_date = "2024-06-15" + outlet_events[single_key_asset].partition_keys = [report_date] + + generate_report() + + +@asset( + uri="file://analytics/daily-report-summary.csv", + name="daily_report_summary", + schedule=PartitionedAssetTimetable(assets=single_key_asset), + tags=["runtime-partition", "single-key"], +) +def summarise_daily_report(self): + """ + Summarise a daily report partition. + + Triggered by ``daily_report`` with the partition key set at runtime by + ``produce_daily_report``. + """ + pass + + +# --------------------------------------------------------------------------- +# Pattern 2: fan-out — one event per discovered partition +# --------------------------------------------------------------------------- + +region_snapshots = Asset( + uri="file://incoming/snapshots/regions.csv", + name="region_snapshots", +) + + +with DAG( + dag_id="discover_and_snapshot_regions", + schedule=PartitionAtRuntime(), + tags=["runtime-partition", "fan-out"], +): + """ + Discover active regions at runtime and snapshot each one. + + The task queries an external source to find which regions are active, + then sets ``partition_keys`` to the full list. The scheduler creates one + downstream Dag run per key via the fan-out mechanism. + """ + + @task(outlets=[region_snapshots]) + def snapshot_regions(*, outlet_events): + """Emit one asset event per discovered region.""" + # Regions could come from an API call, a database query, etc. + active_regions = ["us-east", "eu-west", "ap-south"] + outlet_events[region_snapshots].partition_keys = active_regions + + snapshot_regions() + + +@asset( + uri="file://analytics/region-aggregation.csv", + name="region_aggregation", + schedule=PartitionedAssetTimetable( + assets=region_snapshots, + default_partition_mapper=AllowedKeyMapper(["us-east", "eu-west", "ap-south"]), + ), + tags=["runtime-partition", "fan-out"], +) +def aggregate_region(self): + """ + Aggregate data for a single region partition. + + One run of this asset is triggered for each region key emitted by + ``discover_and_snapshot_regions``. + """ + pass + + +# --------------------------------------------------------------------------- +# Pattern 3: fan-out with per-partition metadata via PartitionKey +# --------------------------------------------------------------------------- + +raw_feed = Asset( + uri="file://incoming/feeds/raw.csv", + name="raw_feed", +) + + +with DAG( + dag_id="ingest_feeds_with_metadata", + schedule=PartitionAtRuntime(), + tags=["runtime-partition", "fan-out", "partition-key-extra"], +): + """ + Ingest data feeds discovered at runtime, attaching source metadata to each. + + :class:`~airflow.sdk.PartitionKey` objects let you carry per-partition + ``extra`` data alongside the key. That metadata is merged into the asset + event and made available to downstream tasks via ``inlet_events``. + """ + + @task(outlets=[raw_feed]) + def ingest_feeds(*, outlet_events): + """Discover feeds and emit one partitioned event per feed with source metadata.""" + feeds = [ + ("feed_a", "s3://bucket/feed_a/latest.csv"), + ("feed_b", "s3://bucket/feed_b/latest.csv"), + ("feed_c", "s3://bucket/feed_c/latest.csv"), + ] + outlet_events[raw_feed].partition_keys = [ + PartitionKey(key=feed_id, extra={"source_uri": uri}) for feed_id, uri in feeds + ] + + ingest_feeds() + + +@asset( + uri="file://analytics/feeds/processed.csv", + name="processed_feed", + schedule=PartitionedAssetTimetable(assets=raw_feed), + tags=["runtime-partition", "fan-out", "partition-key-extra"], +) +def process_feed(self): + """ + Process a single feed partition. + + One run of this asset is triggered per feed key. The ``source_uri`` + attached via :class:`~airflow.sdk.PartitionKey` is available on the + triggering asset event in ``inlet_events``. + """ + pass diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 6fb2b38696cf1..dd6e8b3bd3bf7 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1428,7 +1428,6 @@ def register_asset_changes_in_db( outlet_events: list[dict[str, Any]], session: Session = NEW_SESSION, ) -> None: - print(task_outlets, outlet_events) from airflow.serialization.definitions.assets import ( SerializedAsset, SerializedAssetNameRef, @@ -1436,10 +1435,29 @@ def register_asset_changes_in_db( SerializedAssetUriRef, ) - # TODO: AIP-76 should we provide an interface to override this, so that the task can - # tell the truth if for some reason it touches a different partition? - # https://github.com/apache/airflow/issues/58474 - partition_key = ti.dag_run.partition_key + # Build per-event partition_keys lists from outlet events (non-alias events only). + # Each entry is a list of {"key": str, "extra": dict} dicts; empty list means no + # runtime partition key was set by the task. + asset_event_partition_keys_list: dict[SerializedAssetUniqueKey, list[dict]] = { + SerializedAssetUniqueKey(**event["dest_asset_key"]): event.get("partition_keys") or [] + for event in outlet_events + if "source_alias_name" not in event + } + + # Collect all unique partition key strings emitted by this task run. + all_runtime_pks: set[str] = { + pk_dict["key"] for pk_list in asset_event_partition_keys_list.values() for pk_dict in pk_list + } + + # When exactly one unique key is used across all outlets, treat this as a + # single-partition run and record it on the dag run (PartitionAtRuntime use case). + # For fan-out (multiple keys) we leave dag_run.partition_key untouched. + if len(all_runtime_pks) == 1: + ti.dag_run.partition_key = next(iter(all_runtime_pks)) + + # Fall back to the dag run's existing partition key (set by timetable or upstream). + dag_run_partition_key = ti.dag_run.partition_key + asset_keys = { SerializedAssetUniqueKey(o.name, o.uri) for o in task_outlets @@ -1466,12 +1484,40 @@ def register_asset_changes_in_db( ) } - asset_event_extras: dict[SerializedAssetUniqueKey, dict] = { + # Base (event-level) extras; merged with per-partition extras when fan-out is used. + asset_event_base_extras: dict[SerializedAssetUniqueKey, dict] = { SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"] for event in outlet_events if "source_alias_name" not in event } + def _register_for_asset( + am: AssetModel, + base_extra: dict, + pk_list: list[dict], + ) -> None: + """Register one or more asset change events depending on partition_keys.""" + if pk_list: + # Fan-out: one event per partition key. + for pk_dict in pk_list: + merged_extra = {**base_extra, **pk_dict["extra"]} + asset_manager.register_asset_change( + task_instance=ti, + asset=am, + extra=merged_extra, + partition_key=pk_dict["key"], + session=session, + ) + else: + # No runtime partition keys — fall back to dag run's partition key. + asset_manager.register_asset_change( + task_instance=ti, + asset=am, + extra=base_extra, + partition_key=dag_run_partition_key, + session=session, + ) + for key in asset_keys: try: am = asset_models[key] @@ -1483,17 +1529,20 @@ def register_asset_changes_in_db( ) continue ti.log.debug("register event for asset %s", am) - asset_manager.register_asset_change( - task_instance=ti, - asset=am, - extra=asset_event_extras.get(key), - partition_key=partition_key, - session=session, + _register_for_asset( + am, + base_extra=asset_event_base_extras.get(key) or {}, + pk_list=asset_event_partition_keys_list.get(key, []), ) if asset_name_refs: asset_models_by_name = {key.name: am for key, am in asset_models.items()} - asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()} + asset_event_base_extras_by_name = { + key.name: extra for key, extra in asset_event_base_extras.items() + } + asset_event_pks_by_name = { + key.name: pk_list for key, pk_list in asset_event_partition_keys_list.items() + } for nref in asset_name_refs: try: am = asset_models_by_name[nref.name] @@ -1503,16 +1552,19 @@ def register_asset_changes_in_db( ) continue ti.log.debug("register event for asset name ref %s", am) - asset_manager.register_asset_change( - task_instance=ti, - asset=am, - extra=asset_event_extras_by_name.get(nref.name), - partition_key=partition_key, - session=session, + _register_for_asset( + am, + base_extra=asset_event_base_extras_by_name.get(nref.name) or {}, + pk_list=asset_event_pks_by_name.get(nref.name, []), ) if asset_uri_refs: asset_models_by_uri = {key.uri: am for key, am in asset_models.items()} - asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()} + asset_event_base_extras_by_uri = { + key.uri: extra for key, extra in asset_event_base_extras.items() + } + asset_event_pks_by_uri = { + key.uri: pk_list for key, pk_list in asset_event_partition_keys_list.items() + } for uref in asset_uri_refs: try: am = asset_models_by_uri[uref.uri] @@ -1522,12 +1574,10 @@ def register_asset_changes_in_db( ) continue ti.log.debug("register event for asset uri ref %s", am) - asset_manager.register_asset_change( - task_instance=ti, - asset=am, - extra=asset_event_extras_by_uri.get(uref.uri), - partition_key=partition_key, - session=session, + _register_for_asset( + am, + base_extra=asset_event_base_extras_by_uri.get(uref.uri) or {}, + pk_list=asset_event_pks_by_uri.get(uref.uri, []), ) def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]: @@ -1567,7 +1617,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s asset=asset, source_alias_names=event_aliase_names, extra=asset_event_extra, - partition_key=partition_key, + partition_key=dag_run_partition_key, session=session, ) if event is None: @@ -1579,7 +1629,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s asset=asset, source_alias_names=event_aliase_names, extra=asset_event_extra, - partition_key=partition_key, + partition_key=dag_run_partition_key, session=session, ) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 2f30511a1e5fb..e201d7965df99 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -57,7 +57,12 @@ AssetTriggeredTimetable, PartitionedAssetTimetable, ) -from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable +from airflow.sdk.definitions.timetables.simple import ( + ContinuousTimetable, + NullTimetable, + OnceTimetable, + PartitionAtRuntime, +) from airflow.sdk.definitions.timetables.trigger import CronPartitionTimetable from airflow.serialization.decoders import decode_deadline_alert from airflow.serialization.definitions.assets import ( @@ -292,6 +297,7 @@ class _Serializer: MultipleCronTriggerTimetable: "airflow.timetables.trigger.MultipleCronTriggerTimetable", NullTimetable: "airflow.timetables.simple.NullTimetable", OnceTimetable: "airflow.timetables.simple.OnceTimetable", + PartitionAtRuntime: "airflow.timetables.simple.PartitionAtRuntime", PartitionedAssetTimetable: "airflow.timetables.simple.PartitionedAssetTimetable", } @@ -317,7 +323,10 @@ def serialize_timetable(self, timetable: BaseTimetable | CoreTimetable) -> dict[ @serialize_timetable.register(ContinuousTimetable) @serialize_timetable.register(NullTimetable) @serialize_timetable.register(OnceTimetable) - def _(self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable) -> dict[str, Any]: + @serialize_timetable.register(PartitionAtRuntime) + def _( + self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable | PartitionAtRuntime + ) -> dict[str, Any]: return {} @serialize_timetable.register diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 01fb12f81dd0c..3970304d9d68d 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -183,6 +183,23 @@ def next_dagrun_info( return DagRunInfo.interval(start, end) +class PartitionAtRuntime(NullTimetable): + """ + Timetable indicating that partition keys are determined at runtime. + + Semantically equivalent to ``NullTimetable`` (the Dag is externally + triggered and never auto-scheduled), but signals that tasks in this Dag + will set partition keys on their outlet events at execution time via + ``outlet_events[asset].partition_keys = [...].`` + """ + + description: str = "Never, partition key(s) set at runtime" + + @property + def summary(self) -> str: + return "PartitionAtRuntime" + + class AssetTriggeredTimetable(_TrivialTimetable): """ Timetable that never schedules anything. diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index d874b6163a039..328cf1275ac69 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -82,6 +82,7 @@ BaseSensorOperator, IdentityMapper, Metadata, + PartitionAtRuntime, PartitionedAssetTimetable, task, task_group, @@ -3444,6 +3445,95 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula assert pakl.target_dag_id == "asset_event_listener" +@pytest.mark.db_test +def test_partition_keys_in_outlet_events_single_key_sets_dag_run(dag_maker, session): + """When a single partition key is emitted, dag_run.partition_key is updated.""" + asset = Asset(name="single_pk_asset") + with dag_maker(dag_id="producer", schedule=PartitionAtRuntime(), session=session) as dag: + EmptyOperator(task_id="t", outlets=[asset]) + dr = dag_maker.create_dagrun(session=session) + assert dr.partition_key is None + [ti] = dr.get_task_instances(session=session) + dest_key = {"name": asset.name, "uri": asset.name} + + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[ensure_serialized_asset(asset).asprofile()], + outlet_events=[ + { + "dest_asset_key": dest_key, + "extra": {"info": "test"}, + "partition_keys": [{"key": "region_a", "extra": {}}], + } + ], + session=session, + ) + assert dr.partition_key == "region_a" + actual_event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) + assert actual_event.partition_key == "region_a" + assert actual_event.extra == {"info": "test"} + + +@pytest.mark.db_test +def test_partition_keys_in_outlet_events_fan_out_creates_multiple_events(dag_maker, session): + """When multiple partition keys are emitted, one asset event is created per key.""" + asset = Asset(name="fan_out_asset") + with dag_maker(dag_id="fan_out_producer", schedule=PartitionAtRuntime(), session=session) as dag: + EmptyOperator(task_id="t", outlets=[asset]) + dr = dag_maker.create_dagrun(session=session) + [ti] = dr.get_task_instances(session=session) + dest_key = {"name": asset.name, "uri": asset.name} + + regions = ["us", "eu", "apac"] + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[ensure_serialized_asset(asset).asprofile()], + outlet_events=[ + { + "dest_asset_key": dest_key, + "extra": {}, + "partition_keys": [ + {"key": region, "extra": {"source": f"s3://{region}"}} for region in regions + ], + } + ], + session=session, + ) + # Fan-out: dag_run.partition_key should NOT be set (multiple keys) + assert dr.partition_key is None + events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all() + assert len(events) == 3 + assert {e.partition_key for e in events} == {"us", "eu", "apac"} + assert {e.extra.get("source") for e in events} == {"s3://us", "s3://eu", "s3://apac"} + + +@pytest.mark.db_test +def test_partition_key_per_partition_extra_merged_with_base_extra(dag_maker, session): + """Per-partition extra is merged with the base event extra.""" + asset = Asset(name="merged_extra_asset") + with dag_maker(dag_id="merged_extra_producer", schedule=PartitionAtRuntime(), session=session) as dag: + EmptyOperator(task_id="t", outlets=[asset]) + dr = dag_maker.create_dagrun(session=session) + [ti] = dr.get_task_instances(session=session) + dest_key = {"name": asset.name, "uri": asset.name} + + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[ensure_serialized_asset(asset).asprofile()], + outlet_events=[ + { + "dest_asset_key": dest_key, + "extra": {"base": "value"}, + "partition_keys": [{"key": "k1", "extra": {"pk": "specific"}}], + } + ], + session=session, + ) + actual_event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) + assert actual_event.partition_key == "k1" + assert actual_event.extra == {"base": "value", "pk": "specific"} + + async def empty_callback_for_deadline(): """Used in deadline tests to confirm that Deadlines and DeadlineAlerts function correctly.""" pass diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 8d10d6aba491c..4f5556fbf7967 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -58,7 +58,9 @@ "ObjectStoragePath", "Param", "ParamsDict", + "PartitionAtRuntime", "PartitionedAssetTimetable", + "PartitionKey", "PartitionMapper", "PokeReturnValue", "ProductMapper", @@ -113,7 +115,14 @@ from airflow.sdk.bases.skipmixin import SkipMixin from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.configuration import AirflowSDKConfigParser - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher + from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAll, + AssetAny, + AssetWatcher, + PartitionKey, + ) from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback @@ -149,6 +158,7 @@ CronDataIntervalTimetable, DeltaDataIntervalTimetable, ) + from airflow.sdk.definitions.timetables.simple import PartitionAtRuntime from airflow.sdk.definitions.timetables.trigger import ( CronPartitionTimetable, CronTriggerTimetable, @@ -202,7 +212,9 @@ "ObjectStoragePath": ".io.path", "Param": ".definitions.param", "ParamsDict": ".definitions.param", + "PartitionAtRuntime": ".definitions.timetables.simple", "PartitionedAssetTimetable": ".definitions.timetables.assets", + "PartitionKey": ".definitions.asset", "PartitionMapper": ".definitions.partition_mappers.base", "PokeReturnValue": ".bases.sensor", "ProductMapper": ".definitions.partition_mappers.product", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 4f5177d85cb24..f052700de2a0d 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -48,6 +48,7 @@ from airflow.sdk.definitions.asset import ( AssetAll as AssetAll, AssetAny as AssetAny, AssetWatcher as AssetWatcher, + PartitionKey as PartitionKey, ) from airflow.sdk.definitions.asset.decorators import asset as asset from airflow.sdk.definitions.asset.metadata import Metadata as Metadata @@ -86,6 +87,7 @@ from airflow.sdk.definitions.timetables.interval import ( CronDataIntervalTimetable, DeltaDataIntervalTimetable, ) +from airflow.sdk.definitions.timetables.simple import PartitionAtRuntime as PartitionAtRuntime from airflow.sdk.definitions.timetables.trigger import ( CronPartitionTimetable, CronTriggerTimetable, @@ -137,7 +139,9 @@ __all__ = [ "ObjectStoragePath", "Param", "PokeReturnValue", + "PartitionAtRuntime", "PartitionedAssetTimetable", + "PartitionKey", "PartitionMapper", "ProductMapper", "SecretCache", diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 51f1b14e3b027..61ad97ba51c8e 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -57,6 +57,7 @@ "AssetRef", "AssetUriRef", "AssetWatcher", + "PartitionKey", ] from airflow.sdk.configuration import conf @@ -513,3 +514,25 @@ class AssetAliasEvent(attrs.AttrsInstance): dest_asset_key: AssetUniqueKey dest_asset_extra: dict[str, JsonValue] extra: dict[str, JsonValue] + + +@attrs.define(frozen=True) +class PartitionKey: + """ + A partition key with optional per-partition metadata. + + Use :class:`PartitionKey` instead of a plain string when you need to attach + extra metadata to a specific partition event: + + .. code-block:: python + + outlet_events[my_asset].partition_keys = [ + PartitionKey(key="region_a", extra={"source": "s3://bucket/region_a"}), + PartitionKey(key="region_b", extra={"source": "s3://bucket/region_b"}), + ] + + Plain strings are also accepted and are equivalent to ``PartitionKey(key=..., extra={})``. + """ + + key: str + extra: dict[str, JsonValue] = attrs.Factory(dict) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 26205f0c58335..1973877150a6e 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -49,10 +49,36 @@ def _validate_asset_function_arguments(f: Callable) -> None: raise TypeError(f"positional-only argument '{name}' without a default is not supported in @asset") +class _AssetSelfProxy: + """ + Proxy for the ``self`` parameter in ``@asset`` functions. + + Allows setting ``partition_keys`` at runtime, which is then propagated + to the outlet event for the dag run. All other attribute reads are + forwarded to the underlying :class:`Asset`. + """ + + def __init__(self, asset: Asset) -> None: + object.__setattr__(self, "_asset", asset) + object.__setattr__(self, "partition_keys", []) + + def __getattr__(self, name: str) -> Any: + return getattr(object.__getattribute__(self, "_asset"), name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == "partition_keys": + object.__setattr__(self, "partition_keys", value) + else: + raise AttributeError( + f"Cannot set '{name}' on @asset self; only 'partition_keys' is settable at runtime" + ) + + class _AssetMainOperator(PythonOperator): def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: super().__init__(**kwargs) self._definition_name = definition_name + self._self_proxy: _AssetSelfProxy | None = None @classmethod def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> Self: @@ -62,7 +88,8 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> inlets=[ Asset.ref(name=inlet_asset_name) for inlet_asset_name, param in inspect.signature(definition._function).parameters.items() - if inlet_asset_name not in ("self", "context") and param.default is inspect.Parameter.empty + if inlet_asset_name not in ("self", "context", "outlet_events") + and param.default is inspect.Parameter.empty ], outlets=list(definition.iter_outlets()), python_callable=definition._function, @@ -86,9 +113,13 @@ def _fetch_asset(name: str) -> Asset: if param.default is not inspect.Parameter.empty: value = param.default elif key == "self": - value = _fetch_asset(self._definition_name) + fetched = _fetch_asset(self._definition_name) + self._self_proxy = _AssetSelfProxy(fetched) + value = self._self_proxy elif key == "context": value = context + elif key == "outlet_events": + value = context["outlet_events"] else: value = _fetch_asset(key) yield key, value @@ -96,6 +127,12 @@ def _fetch_asset(name: str) -> Asset: def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: return dict(self._iter_kwargs(context)) + def execute(self, context: Mapping[str, Any]) -> Any: + result = super().execute(context) + if self._self_proxy is not None and self._self_proxy.partition_keys: + context["outlet_events"][self._self_proxy._asset].partition_keys = self._self_proxy.partition_keys + return result + def _instantiate_task(definition: AssetDefinition | MultiAssetDefinition) -> None: decorated_operator = cast("_TaskDecorator", definition._function) diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/simple.py b/task-sdk/src/airflow/sdk/definitions/timetables/simple.py index 3387cb3c386fd..78ab46b6ac502 100644 --- a/task-sdk/src/airflow/sdk/definitions/timetables/simple.py +++ b/task-sdk/src/airflow/sdk/definitions/timetables/simple.py @@ -46,3 +46,32 @@ class ContinuousTimetable(BaseTimetable): """ active_runs_limit = 1 + + +class PartitionAtRuntime(BaseTimetable): + """ + Marker timetable indicating that a Dag's partition key is determined at runtime. + + Use ``schedule=PartitionAtRuntime()`` to signal that tasks in this Dag will + set partition keys on their outlet events at execution time, rather than + having partition keys derived from a cron expression or an upstream asset + event. + + Like ``schedule=None``, the Dag is not scheduled automatically — it must be + triggered externally (manually or via the API). The difference is semantic: + it tells readers and tooling that the Dag is expected to emit partitioned + asset events whose keys are discovered at runtime. + + .. code-block:: python + + with DAG("ingest_regions", schedule=PartitionAtRuntime()): + + @task(outlets=[region_stats]) + def discover_and_ingest(*, outlet_events): + regions = fetch_active_regions() + outlet_events[region_stats].partition_keys = regions + + discover_and_ingest() + """ + + can_be_scheduled = False diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index db5a75e10c18d..50af6880bedd8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -40,6 +40,7 @@ AssetUniqueKey, AssetUriRef, BaseAssetUniqueKey, + PartitionKey, ) from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType from airflow.sdk.log import mask_secret @@ -484,6 +485,35 @@ class OutletEventAccessor(_AssetRefResolutionMixin): key: BaseAssetUniqueKey extra: dict[str, JsonValue] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + partition_keys: list[str | PartitionKey] = attrs.Factory(list) + + def add_partition( + self, + key: str | PartitionKey, + *, + extra: dict[str, JsonValue] | None = None, + ) -> None: + """ + Add a partition key to this outlet event. + + Equivalent to appending to :attr:`partition_keys` directly, but also + lets you attach per-partition ``extra`` metadata inline:: + + outlet_events[asset].add_partition("region_a", extra={"source": "s3://…"}) + outlet_events[asset].add_partition(PartitionKey(key="region_b", extra={…})) + + :param key: The partition key string or a :class:`~airflow.sdk.PartitionKey` + instance that already carries extra metadata. + :param extra: Per-partition extra metadata. When *key* is a plain string + this is set directly. When *key* is already a :class:`PartitionKey`, + the values are merged (``extra`` takes precedence). + """ + if isinstance(key, str): + self.partition_keys.append(PartitionKey(key=key, extra=extra or {})) + elif extra is not None: + self.partition_keys.append(PartitionKey(key=key.key, extra={**key.extra, **extra})) + else: + self.partition_keys.append(key) def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 73e8fd0cbf62e..f9a9b9a1ddc1d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1102,11 +1102,22 @@ def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]: if TYPE_CHECKING: assert isinstance(events, OutletEventAccessors) + from airflow.sdk.definitions.asset import PartitionKey + # We just collect everything the user recorded in the accessors. # Further filtering will be done in the API server. for key, accessor in events._dict.items(): if isinstance(key, AssetUniqueKey): - yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra} + yield { + "dest_asset_key": attrs.asdict(key), + "extra": accessor.extra, + "partition_keys": [ + {"key": pk.key, "extra": pk.extra} + if isinstance(pk, PartitionKey) + else {"key": pk, "extra": {}} + for pk in accessor.partition_keys + ], + } for alias_event in accessor.asset_alias_events: yield attrs.asdict(alias_event) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 4ce43c57c117b..486af9e9664c6 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -36,7 +36,14 @@ from airflow.sdk._shared.logging.types import Logger as Logger from airflow.sdk.api.datamodels._generated import PreviousTIResponse, TaskInstanceState from airflow.sdk.bases.operator import BaseOperator - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey + from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetRef, + BaseAssetUniqueKey, + PartitionKey, + ) from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.execution_time.comms import DagResult @@ -204,6 +211,7 @@ class OutletEventAccessorProtocol(Protocol): key: BaseAssetUniqueKey extra: dict[str, JsonValue] asset_alias_events: list[AssetAliasEvent] + partition_keys: list[str | PartitionKey] def __init__( self, @@ -211,8 +219,15 @@ def __init__( key: BaseAssetUniqueKey, extra: dict[str, JsonValue], asset_alias_events: list[AssetAliasEvent], + partition_keys: list[str | PartitionKey] | None = None, ) -> None: ... def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) -> None: ... + def add_partition( + self, + key: str | PartitionKey, + *, + extra: dict[str, JsonValue] | None = None, + ) -> None: ... class OutletEventAccessorsProtocol(Protocol): diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index a2c3310822b5f..c6a3ea1eb1c4b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -31,6 +31,7 @@ AssetAliasEvent, AssetAliasUniqueKey, AssetUniqueKey, + PartitionKey, ) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable @@ -420,6 +421,58 @@ def test_add_with_db(self, add_args, key, asset_alias_events, mock_supervisor_co assert outlet_event_accessor.asset_alias_events == asset_alias_events +class TestOutletEventAccessorPartitionKeys: + """Tests for add_partition() and partition_keys list on OutletEventAccessor.""" + + def _make_accessor(self) -> OutletEventAccessor: + key = AssetUniqueKey.from_asset(Asset("test_asset")) + return OutletEventAccessor(key=key, extra={}) + + def test_add_partition_str(self): + accessor = self._make_accessor() + accessor.add_partition("region_a") + assert accessor.partition_keys == [PartitionKey(key="region_a", extra={})] + + def test_add_partition_str_with_extra(self): + accessor = self._make_accessor() + accessor.add_partition("region_a", extra={"source": "s3://bucket"}) + assert accessor.partition_keys == [PartitionKey(key="region_a", extra={"source": "s3://bucket"})] + + def test_add_partition_partition_key_object(self): + accessor = self._make_accessor() + pk = PartitionKey(key="region_b", extra={"x": 1}) + accessor.add_partition(pk) + assert accessor.partition_keys == [PartitionKey(key="region_b", extra={"x": 1})] + + def test_add_partition_partition_key_with_extra_override(self): + accessor = self._make_accessor() + pk = PartitionKey(key="region_c", extra={"a": 1}) + accessor.add_partition(pk, extra={"b": 2}) + assert accessor.partition_keys == [PartitionKey(key="region_c", extra={"a": 1, "b": 2})] + + def test_add_partition_multiple(self): + accessor = self._make_accessor() + accessor.add_partition("us") + accessor.add_partition("eu") + accessor.add_partition("apac") + assert len(accessor.partition_keys) == 3 + assert [pk.key for pk in accessor.partition_keys] == ["us", "eu", "apac"] # type: ignore[union-attr] + + def test_partition_keys_setter_plain_strings(self): + accessor = self._make_accessor() + accessor.partition_keys = ["key1", "key2"] + assert accessor.partition_keys == ["key1", "key2"] + + def test_partition_keys_setter_partition_key_objects(self): + accessor = self._make_accessor() + accessor.partition_keys = [PartitionKey(key="k1", extra={"a": 1})] + assert accessor.partition_keys == [PartitionKey(key="k1", extra={"a": 1})] + + def test_default_partition_keys_is_empty(self): + accessor = self._make_accessor() + assert accessor.partition_keys == [] + + class TestTriggeringAssetEventsAccessor: @pytest.fixture(autouse=True) def clear_cache(self):