Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ class AssetEventDagRunReference(StrictBaseModel):
source_map_index: int | None
source_aliases: list[AssetAliasReferenceAssetEventDagRun]
timestamp: UtcDateTime
partition_key: str | None = None


class DagRun(StrictBaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from airflow.api_fastapi.execution_api.datamodels.dagrun import TriggerDAGRunPayload
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
AssetEventDagRunReference,
DagRun,
TIDeferredStatePayload,
TIRunContext,
Expand All @@ -43,15 +44,20 @@ class AddPartitionKeyField(VersionChange):
instructions_to_migrate_to_previous_version = (
schema(DagRun).field("partition_key").didnt_exist,
schema(AssetEventResponse).field("partition_key").didnt_exist,
schema(AssetEventDagRunReference).field("partition_key").didnt_exist,
schema(TriggerDAGRunPayload).field("partition_key").didnt_exist,
schema(DagRunAssetReference).field("partition_key").didnt_exist,
)

@convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type]
def remove_partition_key_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc]
"""Remove the `partition_key` field from the dag_run object when converting to the previous version."""
if "dag_run" in response.body and isinstance(response.body["dag_run"], dict):
response.body["dag_run"].pop("partition_key", None)
dag_run = response.body.get("dag_run")
if isinstance(dag_run, dict):
dag_run.pop("partition_key", None)
for event in dag_run.get("consumed_asset_events") or ():
if isinstance(event, dict):
event.pop("partition_key", None)

@convert_response_to_previous_version_for(AssetEventsResponse) # type: ignore[arg-type]
def remove_partition_key_from_asset_events(response: ResponseInfo) -> None: # type: ignore[misc]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,44 @@ def setup_method(self):
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()
clear_db_assets()

def teardown_method(self):
clear_db_logs()
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()
clear_db_assets()

def test_ti_run_context_exposes_consumed_event_partition_key(self, client, session, create_task_instance):
"""The partition key of each consumed asset event is returned in the run context."""
ti = create_task_instance(
task_id="test_consumed_event_partition_key",
state=State.QUEUED,
session=session,
)
asset = AssetModel(name="upstream", uri="s3://bucket/upstream", group="asset", extra={})
session.add_all([asset, AssetActive.for_asset(asset)])
session.flush()
ti.dag_run.consumed_asset_events.append(
AssetEvent(asset_id=asset.id, source_dag_id="src", source_run_id="r1", partition_key="2024-01-15")
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "h",
"unixname": "u",
"pid": 1,
"start_date": "2024-09-30T12:00:00Z",
},
)

assert response.status_code == 200
events = response.json()["dag_run"]["consumed_asset_events"]
assert [e["partition_key"] for e in events] == ["2024-01-15"]

@pytest.mark.parametrize(
("max_tries", "should_retry"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import pytest

from airflow._shared.timezones import timezone
from airflow.models.asset import AssetActive, AssetEvent, AssetModel
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.state import DagRunState, State

from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.db import clear_db_assets, clear_db_runs
from tests_common.test_utils.format_datetime import from_datetime_to_zulu_without_ms

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -253,3 +254,54 @@ def test_head_version_returns_raw_serde_format(self, client, session, create_tas
assert response.status_code == 200
# Head version gets the plain dict directly -- no BaseSerialization wrapping
assert response.json()["next_kwargs"] == {"cheesecake": True, "event": "payload"}


class TestConsumedEventPartitionKeyBackwardCompat:
"""The partition_key on consumed asset events is stripped for pre-2026-04-06 clients."""

@pytest.fixture(autouse=True)
def _freeze_time(self, time_machine):
time_machine.move_to(TIMESTAMP_STR, tick=False)

def setup_method(self):
clear_db_runs()
clear_db_assets()

def teardown_method(self):
clear_db_runs()
clear_db_assets()

def _create_ti_with_consumed_event(self, session, create_task_instance):
ti = create_task_instance(
task_id="test_consumed_event_partition_key_compat",
state=State.QUEUED,
session=session,
start_date=TIMESTAMP,
)
asset = AssetModel(name="upstream", uri="s3://bucket/upstream", group="asset", extra={})
session.add_all([asset, AssetActive.for_asset(asset)])
session.flush()
ti.dag_run.consumed_asset_events.append(
AssetEvent(asset_id=asset.id, source_dag_id="src", source_run_id="r1", partition_key="2024-01-15")
)
session.commit()
return ti

def test_old_version_strips_partition_key(self, old_ver_client, session, create_task_instance):
ti = self._create_ti_with_consumed_event(session, create_task_instance)

response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY)

assert response.status_code == 200
events = response.json()["dag_run"]["consumed_asset_events"]
assert events
assert all("partition_key" not in event for event in events)

def test_head_version_keeps_partition_key(self, client, session, create_task_instance):
ti = self._create_ti_with_consumed_event(session, create_task_instance)

response = client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY)

assert response.status_code == 200
events = response.json()["dag_run"]["consumed_asset_events"]
assert [event["partition_key"] for event in events] == ["2024-01-15"]
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ class AssetEventDagRunReference(BaseModel):
source_map_index: Annotated[int | None, Field(title="Source Map Index")] = None
source_aliases: Annotated[list[AssetAliasReferenceAssetEventDagRun], Field(title="Source Aliases")]
timestamp: Annotated[AwareDatetime, Field(title="Timestamp")]
partition_key: Annotated[str | None, Field(title="Partition Key")] = None


class AssetEventResponse(BaseModel):
Expand Down
12 changes: 12 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -4502,6 +4502,18 @@
"format": "date-time",
"title": "Timestamp",
"type": "string"
},
"partition_key": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Partition Key"
}
},
"required": [
Expand Down
18 changes: 18 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,24 @@ def test_getitem_uri_ref(
assert mock_supervisor_comms.send.mock_calls == [mock.call(GetAssetByUri(uri=uri))]
assert _AssetRefResolutionMixin._asset_ref_cache

def test_partition_key_exposed(self):
"""A consumed asset event's partition key is reachable via triggering_asset_events."""
event = {
"asset": {"name": "1", "uri": "1", "extra": {}},
"extra": {},
"source_task_id": "t1",
"source_dag_id": "d1",
"source_run_id": "r1",
"source_map_index": -1,
"source_aliases": [],
"timestamp": "2025-01-01T00:00:12Z",
"partition_key": "2024-01-15",
}
accessor = TriggeringAssetEventsAccessor.build(
[AssetEventDagRunReferenceResult.model_validate(event)]
)
assert [e.partition_key for e in accessor[Asset("1")]] == ["2024-01-15"]

def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor):
events = accessor[Asset("2")]
assert len(events) == 1
Expand Down
Loading