Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import Discriminator, Field, Tag, WithJsonSchema
from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
Expand All @@ -30,6 +30,8 @@
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
from airflow.utils.types import DagRunType

AwareDatetimeAdapter = TypeAdapter(AwareDatetime)


class TIEnterRunningPayload(BaseModel):
"""Schema for updating TaskInstance to 'RUNNING' state with minimal required fields."""
Expand Down Expand Up @@ -83,6 +85,12 @@ class TIDeferredStatePayload(BaseModel):
next_method: str
trigger_timeout: timedelta | None = None

@field_validator("trigger_kwargs")
def validate_moment(cls, v):
if "moment" in v:
v["moment"] = AwareDatetimeAdapter.validate_strings(v["moment"])
return v


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
"""
Expand Down
1 change: 1 addition & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def ti_update_state(
kwargs=ti_patch_payload.trigger_kwargs,
)
session.add(trigger_row)
session.flush()

# TODO: HANDLE execution timeout later as it requires a call to the DB
# either get it from the serialised DAG or get it from the API
Expand Down
12 changes: 9 additions & 3 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

payload = {
"state": "deferred",
"trigger_kwargs": {"key": "value"},
"trigger_kwargs": {"key": "value", "moment": "2024-12-18T00:00:00Z"},
"classpath": "my-classpath",
"next_method": "execute_callback",
"trigger_timeout": "P1D", # 1 day
Expand All @@ -277,14 +277,20 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].next_method == "execute_callback"
assert tis[0].next_kwargs == {"key": "value"}
assert tis[0].next_kwargs == {
"key": "value",
"moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}
assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc)

t = session.query(Trigger).all()
assert len(t) == 1
assert t[0].created_date == instant
assert t[0].classpath == "my-classpath"
assert t[0].kwargs == {"key": "value"}
assert t[0].kwargs == {
"key": "value",
"moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}


class TestTIHealthEndpoint:
Expand Down