Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def check_data_intervals(self):
return self

def validate_context(self, dag: SerializedDAG) -> dict:
if (
self.partition_key is not None
and not dag.timetable.partitioned
and not dag.timetable.partitioned_at_runtime
):
raise ValueError(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rejection only runs on the public REST trigger path. The Execution API route (trigger_dag_run in execution_api/routes/dag_runs.py:144) passes partition_key straight into trigger_dag and then create_dagrun with no partitionability check, so a non-partitioned run can still be created with a key there; with the new inheritance, that run then emits partitioned-looking AssetEvents. TriggerDagRunOperator doesn't expose partition_key today, so this isn't reachable through the shipped operator, but the wire field is still forwarded ungated. Putting the check in create_dagrun/DagRun.__init__ would cover both entrypoints instead of just the request body.

f"Dag '{dag.dag_id}' is not a partitioned Dag and does not accept a partition_key."
)
coerced_logical_date = timezone.coerce_datetime(self.logical_date)
run_after = self.run_after or timezone.utcnow()
data_interval = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,8 @@ def trigger_dag_run(
triggered_by = DagRunTriggeredByType.REST_API

dag = get_latest_version_of_dag(dag_bag, dag_id, session)
params = body.validate_context(dag)

try:
params = body.validate_context(dag)
dag_run = dag.create_dagrun(
run_id=params["run_id"],
logical_date=params["logical_date"],
Expand Down
16 changes: 4 additions & 12 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,6 @@ def register_asset_changes_in_db(
)

payloads_by_asset: dict[SerializedAssetUniqueKey, list[OutletEventPayload]] = defaultdict(list)
runtime_pks: set[str] = set()
for outlet_event in outlet_events:
# Alias-emitted events are handled separately further down via
# register_asset_change_for_alias, which uses the DagRun-level
Expand All @@ -1513,16 +1512,6 @@ def register_asset_changes_in_db(
payloads_by_asset[asset_key].append(
OutletEventPayload(extra=outlet_event["extra"], partition_key=partition_key)
)
if partition_key is not None:
runtime_pks.add(partition_key)

# Back-fill DagRun.partition_key from the task emission when the task
# emitted exactly one distinct partition_key across all outlet events
# and the DagRun did not already have one set. This lets a task that
# discovers the partition at runtime (rather than via params) act as
# the source of truth for the DagRun-level key.
if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
ti.dag_run.partition_key = next(iter(runtime_pks))
dag_run_partition_key = ti.dag_run.partition_key

asset_keys = {
Expand Down Expand Up @@ -1563,11 +1552,14 @@ def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
)
return
for payload in payloads_for_asset:
effective_pk = (
payload.partition_key if payload.partition_key is not None else dag_run_partition_key
)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=payload.extra,
partition_key=payload.partition_key,
partition_key=effective_pk,
session=session,
)

Expand Down
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ class PartitionAtRuntime(NullTimetable):
Timetable that never schedules anything; partition keys are set at runtime.

This corresponds to ``schedule=PartitionAtRuntime()``.

A run's ``partition_key`` (run-level provenance) must be supplied at trigger
time — for example via the REST API's ``partition_key`` field. Partition keys
discovered at task runtime populate the emitted :class:`~airflow.sdk.AssetEvent`
records but do **not** back-fill ``DagRun.partition_key`` after the run has
been created.
"""

description: str = "Never, partition key(s) set at runtime"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
from fastapi.testclient import TestClient
from sqlalchemy import func, select, update

from airflow import plugins_manager
from airflow._shared.module_loading import qualname
from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.api_fastapi.core_api.datamodels.dag_versions import DagVersionResponse
from airflow.api_fastapi.core_api.services.public.common import resolve_run_on_latest_version
from airflow.exceptions import ParamValidationError
from airflow.models import DagModel, DagRun, Log
from airflow.models.asset import AssetEvent, AssetModel
from airflow.models.dagbundle import DagBundleModel
Expand All @@ -39,7 +43,10 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.param import Param
from airflow.settings import _configure_async_session
from airflow.timetables.interval import CronDataIntervalTimetable
from airflow.timetables.simple import PartitionAtRuntime
from airflow.timetables.trigger import CronPartitionTimetable
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunTriggeredByType, DagRunType
Expand All @@ -56,6 +63,7 @@
clear_db_serialized_dags,
)
from tests_common.test_utils.format_datetime import from_datetime_to_zulu, from_datetime_to_zulu_without_ms
from unit.listeners.class_listener import ClassBasedListener

if TYPE_CHECKING:
from airflow.models.dag_version import DagVersion
Expand Down Expand Up @@ -83,9 +91,6 @@ def generate_run_id(
@pytest.fixture
def custom_timetable_plugin(monkeypatch):
"""Fixture to register CustomTimetable for serialization."""
from airflow import plugins_manager
from airflow._shared.module_loading import qualname

timetable_class_name = qualname(CustomTimetable)
existing_timetables = getattr(plugins_manager, "timetable_classes", None) or {}

Expand Down Expand Up @@ -1513,8 +1518,6 @@ def test_patch_dag_run_bad_request(self, test_client):
)
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state, listener_manager):
from unit.listeners.class_listener import ClassBasedListener

listener = ClassBasedListener()
listener_manager(listener)
response = test_client.patch(f"/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": state})
Expand Down Expand Up @@ -2213,8 +2216,6 @@ def test_post_dag_runs_with_empty_payload(self, test_client):

@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.create_dagrun")
def test_dagrun_creation_param_validation_error_returns_400(self, mock_create_dagrun, test_client):
from airflow.exceptions import ParamValidationError

now = timezone.utcnow().isoformat()
error_message = "Invalid input for param x"
mock_create_dagrun.side_effect = ParamValidationError(error_message)
Expand Down Expand Up @@ -2466,6 +2467,74 @@ def test_custom_timetable_generate_run_id_for_manual_trigger(self, dag_maker, te
run = session.scalars(select(DagRun).where(DagRun.run_id == run_id_without_logical_date)).one()
assert run.dag_id == custom_dag_id

def test_should_respond_400_when_partition_key_given_for_non_partitioned_dag(self, test_client):
    """A REST trigger carrying ``partition_key`` for a non-partitioned Dag yields 400, not 500.

    ``TriggerDAGRunPostBody.validate_context()`` performs the check; it is invoked inside the
    route's try/except block, which is what turns the ``ValueError`` into an HTTP 400.
    """
    payload = {
        "logical_date": timezone.utcnow().isoformat(),
        "partition_key": "some-partition",
    }

    resp = test_client.post(f"/dags/{DAG1_ID}/dagRuns", json=payload)

    # Status is checked first so a non-JSON error body cannot mask the real failure.
    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "not a partitioned Dag" in detail

@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_should_respond_200_when_partition_key_given_for_partitioned_dag(
    self, dag_maker, test_client, session
):
    """Happy-path guard: a partitioned Dag accepts ``partition_key`` on REST trigger.

    The Dag is scheduled with ``CronPartitionTimetable`` so that the rejection applied to
    non-partitioned Dags must not fire here.
    """
    dag_id = "test_partitioned_dag_trigger"
    cron_partition_schedule = CronPartitionTimetable("0 * * * *", timezone="UTC")

    with dag_maker(
        dag_id=dag_id,
        schedule=cron_partition_schedule,
        start_date=START_DATE1,
        session=session,
        serialized=True,
    ):
        EmptyOperator(task_id="task")
    session.commit()

    payload = {"logical_date": None, "partition_key": "2025-01-01T00:00:00"}
    resp = test_client.post(f"/dags/{dag_id}/dagRuns", json=payload)

    assert resp.status_code == 200

@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_should_respond_200_when_partition_key_given_for_partition_at_runtime_dag(
    self, dag_maker, test_client, session
):
    """A ``PartitionAtRuntime`` Dag accepts ``partition_key`` on REST trigger as well.

    The request-body validation lets a key through when the timetable reports either
    ``partitioned`` or ``partitioned_at_runtime``; this test covers the second case.
    """
    dag_id = "test_partition_at_runtime_dag_trigger"
    runtime_schedule = PartitionAtRuntime()

    with dag_maker(
        dag_id=dag_id,
        schedule=runtime_schedule,
        start_date=START_DATE1,
        session=session,
        serialized=True,
    ):
        EmptyOperator(task_id="task")
    session.commit()

    payload = {"logical_date": None, "partition_key": "runtime-key"}
    resp = test_client.post(f"/dags/{dag_id}/dagRuns", json=payload)

    assert resp.status_code == 200


class TestResolveRunOnLatestVersion:
@pytest.mark.parametrize("explicit_value", [True, False])
Expand Down Expand Up @@ -2556,8 +2625,6 @@ class TestWaitDagRun:
# test at least makes the tests run correctly.
@pytest.fixture(autouse=True)
def reconfigure_async_db_engine(self):
    """Call ``_configure_async_session`` ahead of each test that uses this autouse fixture."""
    from airflow.settings import _configure_async_session

    _configure_async_session()

def test_should_respond_401(self, unauthenticated_test_client):
Expand Down Expand Up @@ -2601,8 +2668,6 @@ def test_collect_task(self, test_client):
assert data == {"state": DagRunState.SUCCESS, "results": {"task_1": '"result_1"'}}

def test_should_respond_403_when_user_lacks_xcom_permission(self, test_client):
from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails

with mock.patch(
"airflow.api_fastapi.core_api.routes.public.dag_run.get_auth_manager",
autospec=True,
Expand Down
82 changes: 74 additions & 8 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep, _UpstreamTIStates
from airflow.timetables.simple import PartitionAtRuntime
from airflow.utils.session import create_session, provide_session
from airflow.utils.span_status import SpanStatus
from airflow.utils.state import DagRunState, State, TaskInstanceState
Expand Down Expand Up @@ -3513,8 +3514,13 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
assert pakl.target_dag_id == "asset_event_listener"


def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
"""Single runtime key on a PartitionAtRuntime-style run (dag_run.partition_key=None) back-fills the run."""
def test_runtime_partition_key_does_not_backfill_dag_run_when_none(dag_maker, session):
"""Task-emitted partition_key lands on the AssetEvent but does NOT back-fill DagRun.partition_key.

DagRun.partition_key (provenance) is set by the scheduler / trigger side, not by task
runtime emission. A run that started with partition_key=None should remain None even when
a task emits an outlet event carrying its own key.
"""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
Expand All @@ -3533,7 +3539,7 @@ def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
assert event.partition_key == "us"
session.refresh(dr)
assert dr.partition_key == "us"
assert dr.partition_key is None


def test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, session):
Expand Down Expand Up @@ -3582,8 +3588,13 @@ def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, session)
assert dr.partition_key is None


def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker, session):
"""An outlet event without partition_key produces an AssetEvent with partition_key=None."""
def test_runtime_partition_key_inherits_dag_run_key_when_event_has_no_key(dag_maker, session):
"""An outlet event without partition_key inherits DagRun.partition_key as the routing pointer.

When a task emits an outlet event that carries no explicit partition_key, the resulting
AssetEvent should inherit the DagRun's partition_key so that downstream partitioned consumers
can still be routed correctly.
"""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_none", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
Expand All @@ -3599,11 +3610,11 @@ def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker, session)
session=session,
)
event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
assert event.partition_key is None
assert event.partition_key == "from-run"


def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
"""One event with partition_key + one without produce two AssetEvents (key + None)."""
"""One event with an explicit key + one without produce two AssetEvents (explicit key + inherited run key)."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
Expand All @@ -3620,11 +3631,66 @@ def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
session=session,
)
events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
assert {e.partition_key for e in events} == {"us", None}
assert {e.partition_key for e in events} == {"us", "from-run"}
session.refresh(dr)
assert dr.partition_key == "from-run"


def test_runtime_partition_key_event_stays_none_when_no_key_and_no_run_key(dag_maker, session):
    """Neither the outlet event nor the DagRun carries a key, so the AssetEvent has none.

    Pins the ``is not None`` fallback: with both sides ``None`` the effective partition key
    stays ``None`` and nothing is written to ``AssetEvent.partition_key``.
    """
    hello_asset = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_both_none", schedule=None) as producer_dag:
        EmptyOperator(task_id="hi", outlets=[hello_asset])

    dag_run = dag_maker.create_dagrun(session=session)
    assert dag_run.partition_key is None
    (task_instance,) = dag_run.get_task_instances(session=session)

    outlet_profile = ensure_serialized_asset(hello_asset).asprofile()
    keyless_event = {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}}
    TaskInstance.register_asset_changes_in_db(
        ti=task_instance,
        task_outlets=[outlet_profile],
        outlet_events=[keyless_event],
        session=session,
    )

    emitted_by_dag = select(AssetEvent).where(AssetEvent.source_dag_id == producer_dag.dag_id)
    stored_event = session.scalar(emitted_by_dag)
    assert stored_event.partition_key is None


def test_runtime_partition_key_does_not_backfill_partition_at_runtime_run(dag_maker, session):
    """On a ``PartitionAtRuntime`` run, a task-emitted key reaches the AssetEvent only.

    Provenance contract: ``DagRun.partition_key`` is fixed at run-creation/trigger time.
    The run here starts with ``partition_key=None``; after a task emits an outlet event with
    an explicit key, the AssetEvent must carry that key while the DagRun stays ``None``.
    """
    hello_asset = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_par_backfill", schedule=PartitionAtRuntime()) as producer_dag:
        EmptyOperator(task_id="hi", outlets=[hello_asset])

    dag_run = dag_maker.create_dagrun(session=session)
    assert dag_run.partition_key is None
    (task_instance,) = dag_run.get_task_instances(session=session)

    outlet_profile = ensure_serialized_asset(hello_asset).asprofile()
    keyed_event = {
        "dest_asset_key": {"name": "hello", "uri": "hello"},
        "extra": {},
        "partition_key": "2025-01-01",
    }
    TaskInstance.register_asset_changes_in_db(
        ti=task_instance,
        task_outlets=[outlet_profile],
        outlet_events=[keyed_event],
        session=session,
    )

    emitted_by_dag = select(AssetEvent).where(AssetEvent.source_dag_id == producer_dag.dag_id)
    stored_event = session.scalar(emitted_by_dag)
    assert stored_event.partition_key == "2025-01-01"

    # Reload from the database so the check reflects persisted state.
    session.refresh(dag_run)
    assert dag_run.partition_key is None


def test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
dag_maker,
session,
Expand Down
Loading