diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index bf20bcbc30c55..ad051b3e6d340 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -317,6 +317,7 @@ class AssetEventDagRunReference(StrictBaseModel): source_map_index: int | None source_aliases: list[AssetAliasReferenceAssetEventDagRun] timestamp: UtcDateTime + partition_key: str | None = None class DagRun(StrictBaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py index 5a3b0f0f5fc2d..e3b995011f4cf 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py @@ -29,6 +29,7 @@ ) from airflow.api_fastapi.execution_api.datamodels.dagrun import TriggerDAGRunPayload from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( + AssetEventDagRunReference, DagRun, TIDeferredStatePayload, TIRunContext, @@ -43,6 +44,7 @@ class AddPartitionKeyField(VersionChange): instructions_to_migrate_to_previous_version = ( schema(DagRun).field("partition_key").didnt_exist, schema(AssetEventResponse).field("partition_key").didnt_exist, + schema(AssetEventDagRunReference).field("partition_key").didnt_exist, schema(TriggerDAGRunPayload).field("partition_key").didnt_exist, schema(DagRunAssetReference).field("partition_key").didnt_exist, ) @@ -50,8 +52,12 @@ class AddPartitionKeyField(VersionChange): @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type] def remove_partition_key_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc] """Remove the `partition_key` field from the dag_run object when converting to the previous version.""" - if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): - response.body["dag_run"].pop("partition_key", None) + dag_run = response.body.get("dag_run") + if isinstance(dag_run, dict): + dag_run.pop("partition_key", None) + for event in dag_run.get("consumed_asset_events") or (): + if isinstance(event, dict): + event.pop("partition_key", None) @convert_response_to_previous_version_for(AssetEventsResponse) # type: ignore[arg-type] def remove_partition_key_from_asset_events(response: ResponseInfo) -> None: # type: ignore[misc] diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 90325bb9c8542..67e0a4ee4aee7 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -164,12 +164,44 @@ def setup_method(self): clear_db_runs() clear_db_serialized_dags() clear_db_dags() + clear_db_assets() def teardown_method(self): clear_db_logs() clear_db_runs() clear_db_serialized_dags() clear_db_dags() + clear_db_assets() + + def test_ti_run_context_exposes_consumed_event_partition_key(self, client, session, create_task_instance): + """The partition key of each consumed asset event is returned in the run context.""" + ti = create_task_instance( + task_id="test_consumed_event_partition_key", + state=State.QUEUED, + session=session, + ) + asset = AssetModel(name="upstream", uri="s3://bucket/upstream", group="asset", extra={}) + session.add_all([asset, AssetActive.for_asset(asset)]) + session.flush() + ti.dag_run.consumed_asset_events.append( + AssetEvent(asset_id=asset.id, source_dag_id="src", source_run_id="r1", partition_key="2024-01-15") + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "h", + "unixname": "u", + "pid": 1, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + + assert response.status_code == 200 + events = response.json()["dag_run"]["consumed_asset_events"] + assert [e["partition_key"] for e in events] == ["2024-01-15"] @pytest.mark.parametrize( ("max_tries", "should_retry"), diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py index d855291727756..4a8b59c362e21 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py @@ -20,10 +20,11 @@ import pytest from airflow._shared.timezones import timezone +from airflow.models.asset import AssetActive, AssetEvent, AssetModel from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.state import DagRunState, State -from tests_common.test_utils.db import clear_db_runs +from tests_common.test_utils.db import clear_db_assets, clear_db_runs from tests_common.test_utils.format_datetime import from_datetime_to_zulu_without_ms pytestmark = pytest.mark.db_test @@ -253,3 +254,54 @@ def test_head_version_returns_raw_serde_format(self, client, session, create_tas assert response.status_code == 200 # Head version gets the plain dict directly -- no BaseSerialization wrapping assert response.json()["next_kwargs"] == {"cheesecake": True, "event": "payload"} + + +class TestConsumedEventPartitionKeyBackwardCompat: + """The partition_key on consumed asset events is stripped for pre-2026-04-06 clients.""" + + @pytest.fixture(autouse=True) + def _freeze_time(self, time_machine): + time_machine.move_to(TIMESTAMP_STR, tick=False) + + def setup_method(self): + clear_db_runs() + clear_db_assets() + + def teardown_method(self): + clear_db_runs() + clear_db_assets() + + def _create_ti_with_consumed_event(self, session, create_task_instance): + ti = create_task_instance( + task_id="test_consumed_event_partition_key_compat", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + asset = AssetModel(name="upstream", uri="s3://bucket/upstream", group="asset", extra={}) + session.add_all([asset, AssetActive.for_asset(asset)]) + session.flush() + ti.dag_run.consumed_asset_events.append( + AssetEvent(asset_id=asset.id, source_dag_id="src", source_run_id="r1", partition_key="2024-01-15") + ) + session.commit() + return ti + + def test_old_version_strips_partition_key(self, old_ver_client, session, create_task_instance): + ti = self._create_ti_with_consumed_event(session, create_task_instance) + + response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + events = response.json()["dag_run"]["consumed_asset_events"] + assert events + assert all("partition_key" not in event for event in events) + + def test_head_version_keeps_partition_key(self, client, session, create_task_instance): + ti = self._create_ti_with_consumed_event(session, create_task_instance) + + response = client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + events = response.json()["dag_run"]["consumed_asset_events"] + assert [event["partition_key"] for event in events] == ["2024-01-15"] diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index bc03569386d14..80cee82cb1eaa 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -725,6 +725,7 @@ class AssetEventDagRunReference(BaseModel): source_map_index: Annotated[int | None, Field(title="Source Map Index")] = None source_aliases: Annotated[list[AssetAliasReferenceAssetEventDagRun], Field(title="Source Aliases")] timestamp: Annotated[AwareDatetime, Field(title="Timestamp")] + partition_key: Annotated[str | None, Field(title="Partition Key")] = None class AssetEventResponse(BaseModel): diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 087d149d1af03..08383b1422204 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -4502,6 +4502,18 @@ "format": "date-time", "title": "Timestamp", "type": "string" + }, + "partition_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Partition Key" } }, "required": [ diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index f82f4924fbafc..eb1ffa0013e22 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -660,6 +660,24 @@ def test_getitem_uri_ref( assert mock_supervisor_comms.send.mock_calls == [mock.call(GetAssetByUri(uri=uri))] assert _AssetRefResolutionMixin._asset_ref_cache + def test_partition_key_exposed(self): + """A consumed asset event's partition key is reachable via triggering_asset_events.""" + event = { + "asset": {"name": "1", "uri": "1", "extra": {}}, + "extra": {}, + "source_task_id": "t1", + "source_dag_id": "d1", + "source_run_id": "r1", + "source_map_index": -1, + "source_aliases": [], + "timestamp": "2025-01-01T00:00:12Z", + "partition_key": "2024-01-15", + } + accessor = TriggeringAssetEventsAccessor.build( + [AssetEventDagRunReferenceResult.model_validate(event)] + ) + assert [e.partition_key for e in accessor[Asset("1")]] == ["2024-01-15"] + def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): events = accessor[Asset("2")] assert len(events) == 1