diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py index 97b37fdaae5ea..5aae6c164955b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -161,6 +161,14 @@ def check_data_intervals(self): return self def validate_context(self, dag: SerializedDAG) -> dict: + if ( + self.partition_key is not None + and not dag.timetable.partitioned + and not dag.timetable.partitioned_at_runtime + ): + raise ValueError( + f"Dag '{dag.dag_id}' is not a partitioned Dag and does not accept a partition_key." + ) coerced_logical_date = timezone.coerce_datetime(self.logical_date) run_after = self.run_after or timezone.utcnow() data_interval = None diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py index a5c76550be559..095009ed514ee 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -620,9 +620,8 @@ def trigger_dag_run( triggered_by = DagRunTriggeredByType.REST_API dag = get_latest_version_of_dag(dag_bag, dag_id, session) - params = body.validate_context(dag) - try: + params = body.validate_context(dag) dag_run = dag.create_dagrun( run_id=params["run_id"], logical_date=params["logical_date"], diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 9cc8d6a627140..d1ce97c5ac039 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1499,7 +1499,6 @@ def register_asset_changes_in_db( ) payloads_by_asset: dict[SerializedAssetUniqueKey, list[OutletEventPayload]] = defaultdict(list) - runtime_pks: set[str] = set() for outlet_event in outlet_events: # Alias-emitted events are handled separately further down via # register_asset_change_for_alias, which uses the DagRun-level @@ -1513,16 +1512,6 @@ def register_asset_changes_in_db( payloads_by_asset[asset_key].append( OutletEventPayload(extra=outlet_event["extra"], partition_key=partition_key) ) - if partition_key is not None: - runtime_pks.add(partition_key) - - # Back-fill DagRun.partition_key from the task emission when the task - # emitted exactly one distinct partition_key across all outlet events - # and the DagRun did not already have one set. This lets a task that - # discovers the partition at runtime (rather than via params) act as - # the source of truth for the DagRun-level key. - if len(runtime_pks) == 1 and ti.dag_run.partition_key is None: - ti.dag_run.partition_key = next(iter(runtime_pks)) dag_run_partition_key = ti.dag_run.partition_key asset_keys = { @@ -1563,11 +1552,14 @@ def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None: ) return for payload in payloads_for_asset: + effective_pk = ( + payload.partition_key if payload.partition_key is not None else dag_run_partition_key + ) asset_manager.register_asset_change( task_instance=ti, asset=am, extra=payload.extra, - partition_key=payload.partition_key, + partition_key=effective_pk, session=session, ) diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 086e1153d618a..db5e4ec8d4dd4 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -189,6 +189,12 @@ class PartitionAtRuntime(NullTimetable): Timetable that never schedules anything; partition keys are set at runtime. This corresponds to ``schedule=PartitionAtRuntime()``. + + A run's ``partition_key`` (run-level provenance) must be supplied at trigger + time — for example via the REST API's ``partition_key`` field. Partition keys + discovered at task runtime populate the emitted :class:`~airflow.sdk.AssetEvent` + records but do **not** back-fill ``DagRun.partition_key`` after the run has + been created. """ description: str = "Never, partition key(s) set at runtime" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index bc59d204378fa..845009aa978a6 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -27,10 +27,14 @@ from fastapi.testclient import TestClient from sqlalchemy import func, select, update +from airflow import plugins_manager +from airflow._shared.module_loading import qualname from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser from airflow.api_fastapi.core_api.datamodels.dag_versions import DagVersionResponse from airflow.api_fastapi.core_api.services.public.common import resolve_run_on_latest_version +from airflow.exceptions import ParamValidationError from airflow.models import DagModel, DagRun, Log from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dagbundle import DagBundleModel @@ -39,7 +43,10 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.param import Param +from airflow.settings import _configure_async_session from airflow.timetables.interval import CronDataIntervalTimetable +from airflow.timetables.simple import PartitionAtRuntime +from airflow.timetables.trigger import CronPartitionTimetable from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -56,6 +63,7 @@ clear_db_serialized_dags, ) from tests_common.test_utils.format_datetime import from_datetime_to_zulu, from_datetime_to_zulu_without_ms +from unit.listeners.class_listener import ClassBasedListener if TYPE_CHECKING: from airflow.models.dag_version import DagVersion @@ -83,9 +91,6 @@ def generate_run_id( @pytest.fixture def custom_timetable_plugin(monkeypatch): """Fixture to register CustomTimetable for serialization.""" - from airflow import plugins_manager - from airflow._shared.module_loading import qualname - timetable_class_name = qualname(CustomTimetable) existing_timetables = getattr(plugins_manager, "timetable_classes", None) or {} @@ -1513,8 +1518,6 @@ def test_patch_dag_run_bad_request(self, test_client): ) @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state, listener_manager): - from unit.listeners.class_listener import ClassBasedListener - listener = ClassBasedListener() listener_manager(listener) response = test_client.patch(f"/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": state}) @@ -2213,8 +2216,6 @@ def test_post_dag_runs_with_empty_payload(self, test_client): @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.create_dagrun") def test_dagrun_creation_param_validation_error_returns_400(self, mock_create_dagrun, test_client): - from airflow.exceptions import ParamValidationError - now = timezone.utcnow().isoformat() error_message = "Invalid input for param x" mock_create_dagrun.side_effect = ParamValidationError(error_message) @@ -2466,6 +2467,74 @@ def test_custom_timetable_generate_run_id_for_manual_trigger(self, dag_maker, te run = session.scalars(select(DagRun).where(DagRun.run_id == run_id_without_logical_date)).one() assert run.dag_id == custom_dag_id + def test_should_respond_400_when_partition_key_given_for_non_partitioned_dag(self, test_client): + """Passing partition_key to a non-partitioned Dag via REST trigger must return 400, not 500. + + The validation happens in TriggerDAGRunPostBody.validate_context(), which is now called + inside the try/except block that converts ValueError to HTTP 400. + """ + now = timezone.utcnow().isoformat() + response = test_client.post( + f"/dags/{DAG1_ID}/dagRuns", + json={"logical_date": now, "partition_key": "some-partition"}, + ) + assert response.status_code == 400 + assert "not a partitioned Dag" in response.json()["detail"] + + @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") + def test_should_respond_200_when_partition_key_given_for_partitioned_dag( + self, dag_maker, test_client, session + ): + """partition_key on a genuinely partitioned Dag must not be rejected (happy-path guard). + + Uses CronPartitionTimetable (partitioned=True) to confirm the reject path does not + fire for legitimate partitioned Dags. + """ + partitioned_dag_id = "test_partitioned_dag_trigger" + with dag_maker( + dag_id=partitioned_dag_id, + schedule=CronPartitionTimetable("0 * * * *", timezone="UTC"), + start_date=START_DATE1, + session=session, + serialized=True, + ): + EmptyOperator(task_id="task") + + session.commit() + + response = test_client.post( + f"/dags/{partitioned_dag_id}/dagRuns", + json={"logical_date": None, "partition_key": "2025-01-01T00:00:00"}, + ) + assert response.status_code == 200 + + @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") + def test_should_respond_200_when_partition_key_given_for_partition_at_runtime_dag( + self, dag_maker, test_client, session + ): + """partition_key on a PartitionAtRuntime Dag must also be accepted (deferred validation). + + partitioned_at_runtime=True means the Dag accepts runtime-discovered partition keys, so + the REST layer must not reject it even though timetable.partitioned is False. + """ + runtime_dag_id = "test_partition_at_runtime_dag_trigger" + with dag_maker( + dag_id=runtime_dag_id, + schedule=PartitionAtRuntime(), + start_date=START_DATE1, + session=session, + serialized=True, + ): + EmptyOperator(task_id="task") + + session.commit() + + response = test_client.post( + f"/dags/{runtime_dag_id}/dagRuns", + json={"logical_date": None, "partition_key": "runtime-key"}, + ) + assert response.status_code == 200 + class TestResolveRunOnLatestVersion: @pytest.mark.parametrize("explicit_value", [True, False]) @@ -2556,8 +2625,6 @@ class TestWaitDagRun: # test at least makes the tests run correctly. @pytest.fixture(autouse=True) def reconfigure_async_db_engine(self): - from airflow.settings import _configure_async_session - _configure_async_session() def test_should_respond_401(self, unauthenticated_test_client): @@ -2601,8 +2668,6 @@ def test_collect_task(self, test_client): assert data == {"state": DagRunState.SUCCESS, "results": {"task_1": '"result_1"'}} def test_should_respond_403_when_user_lacks_xcom_permission(self, test_client): - from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails - with mock.patch( "airflow.api_fastapi.core_api.routes.public.dag_run.get_auth_manager", autospec=True, diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 3b4dceb628d03..990223c67ec95 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -104,6 +104,7 @@ from airflow.ti_deps.deps.base_ti_dep import TIDepStatus from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep, _UpstreamTIStates +from airflow.timetables.simple import PartitionAtRuntime from airflow.utils.session import create_session, provide_session from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState @@ -3513,8 +3514,13 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula assert pakl.target_dag_id == "asset_event_listener" -def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session): - """Single runtime key on a PartitionAtRuntime-style run (dag_run.partition_key=None) back-fills the run.""" +def test_runtime_partition_key_does_not_backfill_dag_run_when_none(dag_maker, session): + """Task-emitted partition_key lands on the AssetEvent but does NOT back-fill DagRun.partition_key. + + DagRun.partition_key (provenance) is set by the scheduler / trigger side, not by task + runtime emission. A run that started with partition_key=None should remain None even when + a task emits an outlet event carrying its own key. + """ asset = Asset(name="hello") with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag: EmptyOperator(task_id="hi", outlets=[asset]) @@ -3533,7 +3539,7 @@ def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session): event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) assert event.partition_key == "us" session.refresh(dr) - assert dr.partition_key == "us" + assert dr.partition_key is None def test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, session): @@ -3582,8 +3588,13 @@ def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, session) assert dr.partition_key is None -def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker, session): - """An outlet event without partition_key produces an AssetEvent with partition_key=None.""" +def test_runtime_partition_key_inherits_dag_run_key_when_event_has_no_key(dag_maker, session): + """An outlet event without partition_key inherits DagRun.partition_key as the routing pointer. + + When a task emits an outlet event that carries no explicit partition_key, the resulting + AssetEvent should inherit the DagRun's partition_key so that downstream partitioned consumers + can still be routed correctly. + """ asset = Asset(name="hello") with dag_maker(dag_id="rt_pk_none", schedule=None) as dag: EmptyOperator(task_id="hi", outlets=[asset]) @@ -3599,11 +3610,11 @@ def test_runtime_partition_key_is_none_when_event_has_no_key(dag_maker, session) session=session, ) event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) - assert event.partition_key is None + assert event.partition_key == "from-run" def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session): - """One event with partition_key + one without produce two AssetEvents (key + None).""" + """One event with an explicit key + one without produce two AssetEvents (explicit key + inherited run key).""" asset = Asset(name="hello") with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag: EmptyOperator(task_id="hi", outlets=[asset]) @@ -3620,11 +3631,66 @@ def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session): session=session, ) events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all() - assert {e.partition_key for e in events} == {"us", None} + assert {e.partition_key for e in events} == {"us", "from-run"} session.refresh(dr) assert dr.partition_key == "from-run" +def test_runtime_partition_key_event_stays_none_when_no_key_and_no_run_key(dag_maker, session): + """Task has no partition_key and run has no partition_key -> event.partition_key is None. + + Pins the `is not None` guard: both sides are None so effective_pk stays None and no + routing pointer is written to the AssetEvent. + """ + asset = Asset(name="hello") + with dag_maker(dag_id="rt_pk_both_none", schedule=None) as dag: + EmptyOperator(task_id="hi", outlets=[asset]) + dr = dag_maker.create_dagrun(session=session) + assert dr.partition_key is None + [ti] = dr.get_task_instances(session=session) + + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[ensure_serialized_asset(asset).asprofile()], + outlet_events=[ + {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}}, + ], + session=session, + ) + event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) + assert event.partition_key is None + + +def test_runtime_partition_key_does_not_backfill_partition_at_runtime_run(dag_maker, session): + """Task-emitted key lands on AssetEvent but does NOT back-fill DagRun.partition_key on a PartitionAtRuntime run. + + Provenance contract: DagRun.partition_key is set only at run-creation/trigger time. + A PartitionAtRuntime Dag triggered via REST without a partition_key starts with + partition_key=None. Even when a task discovers a partition at runtime and emits an + outlet event carrying an explicit key, DagRun.partition_key must remain None — the + key belongs to the AssetEvent, not to the run's provenance. + """ + asset = Asset(name="hello") + with dag_maker(dag_id="rt_pk_par_backfill", schedule=PartitionAtRuntime()) as dag: + EmptyOperator(task_id="hi", outlets=[asset]) + dr = dag_maker.create_dagrun(session=session) + assert dr.partition_key is None + [ti] = dr.get_task_instances(session=session) + + TaskInstance.register_asset_changes_in_db( + ti=ti, + task_outlets=[ensure_serialized_asset(asset).asprofile()], + outlet_events=[ + {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "2025-01-01"}, + ], + session=session, + ) + event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)) + assert event.partition_key == "2025-01-01" + session.refresh(dr) + assert dr.partition_key is None + + def test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated( dag_maker, session,