@@ -0,0 +1,207 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Dags demonstrating runtime partition key assignment.

These Dags showcase :class:`~airflow.sdk.PartitionAtRuntime` and the new
``partition_keys`` API on outlet events, where partition keys are discovered
and emitted at task execution time rather than being derived from a schedule.

Three patterns are shown:

1. **Single runtime key** — a task resolves which partition to process and sets
exactly one key, which is also recorded on the Dag run.

2. **Fan-out** — a task discovers a dynamic set of partitions at runtime and
emits one asset event per partition, triggering a separate downstream Dag run
for each.

3. **Fan-out with per-partition metadata** — same as fan-out but each
:class:`~airflow.sdk.PartitionKey` also carries extra metadata that is
merged into the corresponding asset event.
"""

from __future__ import annotations

from airflow.sdk import (
DAG,
AllowedKeyMapper,
Asset,
PartitionAtRuntime,
PartitionedAssetTimetable,
PartitionKey,
asset,
task,
)

# ---------------------------------------------------------------------------
# Pattern 1: single runtime partition key
# ---------------------------------------------------------------------------

single_key_asset = Asset(
uri="file://incoming/daily-report.csv",
name="daily_report",
)


with DAG(
dag_id="produce_daily_report",
schedule=PartitionAtRuntime(),
tags=["runtime-partition", "single-key"],
):
"""
Produce a daily report whose partition date is resolved at runtime.

``schedule=PartitionAtRuntime()`` signals that this Dag is externally
triggered and that a task will set the partition key during execution.
Setting exactly one key also records it on the Dag run, so downstream
:class:`~airflow.sdk.PartitionedAssetTimetable` Dags receive the correct
partition context.
"""

@task(outlets=[single_key_asset])
def generate_report(*, outlet_events):
"""Resolve today's report partition and emit a single asset event."""
# In practice, compute the partition from business logic or
# an upstream API response.
report_date = "2024-06-15"
outlet_events[single_key_asset].partition_keys = [report_date]

generate_report()


@asset(
uri="file://analytics/daily-report-summary.csv",
name="daily_report_summary",
schedule=PartitionedAssetTimetable(assets=single_key_asset),
tags=["runtime-partition", "single-key"],
)
def summarise_daily_report(self):
"""
Summarise a daily report partition.

Triggered by ``daily_report`` with the partition key set at runtime by
``produce_daily_report``.
"""
pass
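

# Hedged, illustrative sketch (not part of the three patterns in this example):
# because the single runtime key is also recorded on the Dag run, a downstream
# task could read it back from the ``dag_run`` context object. This assumes the
# ``partition_key`` attribute set on the Dag run by the scheduler-side change
# (see the taskinstance.py diff below) is exposed on the Dag run object passed
# into the task context; the task name ``read_report_partition`` is made up here.


@task
def read_report_partition(*, dag_run):
    """Illustrative only: echo the partition key recorded on the Dag run."""
    print(f"processing report partition {dag_run.partition_key}")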


# ---------------------------------------------------------------------------
# Pattern 2: fan-out — one event per discovered partition
# ---------------------------------------------------------------------------

region_snapshots = Asset(
uri="file://incoming/snapshots/regions.csv",
name="region_snapshots",
)


with DAG(
dag_id="discover_and_snapshot_regions",
schedule=PartitionAtRuntime(),
tags=["runtime-partition", "fan-out"],
):
"""
Discover active regions at runtime and snapshot each one.

The task queries an external source to find which regions are active,
then sets ``partition_keys`` to the full list. The scheduler creates one
downstream Dag run per key via the fan-out mechanism.
"""

@task(outlets=[region_snapshots])
def snapshot_regions(*, outlet_events):
"""Emit one asset event per discovered region."""
# Regions could come from an API call, a database query, etc.
active_regions = ["us-east", "eu-west", "ap-south"]
outlet_events[region_snapshots].partition_keys = active_regions

snapshot_regions()


@asset(
uri="file://analytics/region-aggregation.csv",
name="region_aggregation",
schedule=PartitionedAssetTimetable(
assets=region_snapshots,
default_partition_mapper=AllowedKeyMapper(["us-east", "eu-west", "ap-south"]),
),
tags=["runtime-partition", "fan-out"],
)
def aggregate_region(self):
"""
Aggregate data for a single region partition.

One run of this asset is triggered for each region key emitted by
``discover_and_snapshot_regions``.
"""
pass


# ---------------------------------------------------------------------------
# Pattern 3: fan-out with per-partition metadata via PartitionKey
# ---------------------------------------------------------------------------

raw_feed = Asset(
uri="file://incoming/feeds/raw.csv",
name="raw_feed",
)


with DAG(
dag_id="ingest_feeds_with_metadata",
schedule=PartitionAtRuntime(),
tags=["runtime-partition", "fan-out", "partition-key-extra"],
):
"""
Ingest data feeds discovered at runtime, attaching source metadata to each.

:class:`~airflow.sdk.PartitionKey` objects let you carry per-partition
``extra`` data alongside the key. That metadata is merged into the asset
event and made available to downstream tasks via ``inlet_events``.
"""

@task(outlets=[raw_feed])
def ingest_feeds(*, outlet_events):
"""Discover feeds and emit one partitioned event per feed with source metadata."""
feeds = [
("feed_a", "s3://bucket/feed_a/latest.csv"),
("feed_b", "s3://bucket/feed_b/latest.csv"),
("feed_c", "s3://bucket/feed_c/latest.csv"),
]
outlet_events[raw_feed].partition_keys = [
PartitionKey(key=feed_id, extra={"source_uri": uri}) for feed_id, uri in feeds
]

ingest_feeds()


@asset(
uri="file://analytics/feeds/processed.csv",
name="processed_feed",
schedule=PartitionedAssetTimetable(assets=raw_feed),
tags=["runtime-partition", "fan-out", "partition-key-extra"],
)
def process_feed(self):
"""
Process a single feed partition.

One run of this asset is triggered per feed key. The ``source_uri``
attached via :class:`~airflow.sdk.PartitionKey` is available on the
triggering asset event in ``inlet_events``.
"""
pass
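

# Hedged sketch (not part of the patterns above): how a downstream task could read
# the per-partition ``source_uri`` that was merged into the triggering asset event.
# It uses the existing ``inlet_events`` accessor with ``raw_feed`` declared as an
# inlet; the task itself and the exact shape of the event extra are illustrative
# assumptions, not part of this PR.


@task(inlets=[raw_feed])
def read_feed_source(*, inlet_events):
    """Illustrative only: fetch the source_uri attached via PartitionKey."""
    latest_event = inlet_events[raw_feed][-1]
    source_uri = (latest_event.extra or {}).get("source_uri")
    print(f"feed source: {source_uri}")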
airflow-core/src/airflow/models/taskinstance.py (106 changes: 78 additions & 28 deletions)
@@ -1428,18 +1428,36 @@ def register_asset_changes_in_db(
outlet_events: list[dict[str, Any]],
session: Session = NEW_SESSION,
) -> None:
print(task_outlets, outlet_events)
from airflow.serialization.definitions.assets import (
SerializedAsset,
SerializedAssetNameRef,
SerializedAssetUniqueKey,
SerializedAssetUriRef,
)

# TODO: AIP-76 should we provide an interface to override this, so that the task can
# tell the truth if for some reason it touches a different partition?
# https://github.com/apache/airflow/issues/58474
partition_key = ti.dag_run.partition_key
# Build per-event partition_keys lists from outlet events (non-alias events only).
# Each entry is a list of {"key": str, "extra": dict} dicts; empty list means no
# runtime partition key was set by the task.
asset_event_partition_keys_list: dict[SerializedAssetUniqueKey, list[dict]] = {
SerializedAssetUniqueKey(**event["dest_asset_key"]): event.get("partition_keys") or []
for event in outlet_events
if "source_alias_name" not in event
}

# Collect all unique partition key strings emitted by this task run.
all_runtime_pks: set[str] = {
pk_dict["key"] for pk_list in asset_event_partition_keys_list.values() for pk_dict in pk_list
}

# When exactly one unique key is used across all outlets, treat this as a
# single-partition run and record it on the dag run (PartitionAtRuntime use case).
# For fan-out (multiple keys) we leave dag_run.partition_key untouched.
if len(all_runtime_pks) == 1:
ti.dag_run.partition_key = next(iter(all_runtime_pks))

# Fall back to the dag run's existing partition key (set by timetable or upstream).
dag_run_partition_key = ti.dag_run.partition_key

asset_keys = {
SerializedAssetUniqueKey(o.name, o.uri)
for o in task_outlets
@@ -1466,12 +1484,40 @@ def register_asset_changes_in_db(
)
}

asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
# Base (event-level) extras; merged with per-partition extras when fan-out is used.
asset_event_base_extras: dict[SerializedAssetUniqueKey, dict] = {
SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
for event in outlet_events
if "source_alias_name" not in event
}

def _register_for_asset(
am: AssetModel,
base_extra: dict,
pk_list: list[dict],
) -> None:
"""Register one or more asset change events depending on partition_keys."""
if pk_list:
# Fan-out: one event per partition key.
for pk_dict in pk_list:
merged_extra = {**base_extra, **pk_dict["extra"]}
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=merged_extra,
partition_key=pk_dict["key"],
session=session,
)
else:
# No runtime partition keys — fall back to dag run's partition key.
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=base_extra,
partition_key=dag_run_partition_key,
session=session,
)

for key in asset_keys:
try:
am = asset_models[key]
@@ -1483,17 +1529,20 @@ def register_asset_changes_in_db(
)
continue
ti.log.debug("register event for asset %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras.get(key),
partition_key=partition_key,
session=session,
_register_for_asset(
am,
base_extra=asset_event_base_extras.get(key) or {},
pk_list=asset_event_partition_keys_list.get(key, []),
)

if asset_name_refs:
asset_models_by_name = {key.name: am for key, am in asset_models.items()}
asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()}
asset_event_base_extras_by_name = {
key.name: extra for key, extra in asset_event_base_extras.items()
}
asset_event_pks_by_name = {
key.name: pk_list for key, pk_list in asset_event_partition_keys_list.items()
}
for nref in asset_name_refs:
try:
am = asset_models_by_name[nref.name]
@@ -1503,16 +1552,19 @@ def register_asset_changes_in_db(
)
continue
ti.log.debug("register event for asset name ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_name.get(nref.name),
partition_key=partition_key,
session=session,
_register_for_asset(
am,
base_extra=asset_event_base_extras_by_name.get(nref.name) or {},
pk_list=asset_event_pks_by_name.get(nref.name, []),
)
if asset_uri_refs:
asset_models_by_uri = {key.uri: am for key, am in asset_models.items()}
asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()}
asset_event_base_extras_by_uri = {
key.uri: extra for key, extra in asset_event_base_extras.items()
}
asset_event_pks_by_uri = {
key.uri: pk_list for key, pk_list in asset_event_partition_keys_list.items()
}
for uref in asset_uri_refs:
try:
am = asset_models_by_uri[uref.uri]
@@ -1522,12 +1574,10 @@ def register_asset_changes_in_db(
)
continue
ti.log.debug("register event for asset uri ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_uri.get(uref.uri),
partition_key=partition_key,
session=session,
_register_for_asset(
am,
base_extra=asset_event_base_extras_by_uri.get(uref.uri) or {},
pk_list=asset_event_pks_by_uri.get(uref.uri, []),
)

def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
@@ -1567,7 +1617,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)
if event is None:
@@ -1579,7 +1629,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)
