Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,16 @@ async def get_task_state(self):
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
# The /states endpoint suffixes the response key with ``_{map_index}`` for mapped TIs
# (see ``get_task_instance_states`` in airflow-core's execution_api routes); non-mapped
# TIs keep the plain ``task_id``.
ti_key = (
f"{self.task_instance.task_id}_{self.task_instance.map_index}"
if self.task_instance.map_index >= 0
else self.task_instance.task_id
)
try:
return task_states_response[self.task_instance.run_id][self.task_instance.task_id]
return task_states_response[self.task_instance.run_id][ti_key]
except KeyError:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow.triggers.base import TriggerEvent
from airflow.utils.state import TaskInstanceState

from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS

TRIGGER_PATH = "airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
HOOK_PATH = "airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook"
Expand Down Expand Up @@ -827,6 +827,90 @@ async def test_safe_to_cancel_returns_false_when_task_still_deferred(self, mock_
)
assert await trigger.safe_to_cancel() is False

@pytest.mark.skipif(
    not AIRFLOW_V_3_0_PLUS,
    reason="get_task_state uses RuntimeTaskInstance.get_task_states on Airflow 3.0+",
)
@pytest.mark.asyncio
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
async def test_get_task_state_uses_task_id_for_non_mapped_ti(self, mock_get_task_states):
    """Check that a non-mapped TI is looked up under its plain ``task_id``.

    A task instance with ``map_index < 0`` is not mapped, so the mocked
    states response is keyed by the bare ``task_id`` (no ``_{map_index}``
    suffix), mirroring how the execution API's ``get_task_instance_states``
    handler is described as building its response keys.
    """
    dag_run_id = "manual__2026-05-21T00:00:00+00:00"
    states_by_task = {"my_task": TaskInstanceState.SUCCESS}
    mock_get_task_states.return_value = {dag_run_id: states_by_task}

    pod_trigger = KubernetesPodTrigger(
        pod_name=POD_NAME,
        pod_namespace=NAMESPACE,
        base_container_name=BASE_CONTAINER_NAME,
        trigger_start_time=TRIGGER_START_TIME,
        schedule_timeout=STARTUP_TIMEOUT_SECS,
    )
    # map_index=-1 marks the task instance as non-mapped.
    ti_attrs = {"dag_id": "my_dag", "task_id": "my_task", "run_id": dag_run_id, "map_index": -1}
    pod_trigger.task_instance = MagicMock(**ti_attrs)

    observed_state = await pod_trigger.get_task_state()
    assert observed_state == TaskInstanceState.SUCCESS

@pytest.mark.skipif(
    not AIRFLOW_V_3_0_PLUS,
    reason="get_task_state uses RuntimeTaskInstance.get_task_states on Airflow 3.0+",
)
@pytest.mark.asyncio
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
async def test_get_task_state_uses_composite_key_for_mapped_ti(self, mock_get_task_states):
    """Check that a mapped TI is looked up under ``f"{task_id}_{map_index}"``.

    Regression guard for #67296. Mapped task instances (``map_index >= 0``)
    appear in the response under the composite key. Looking them up by the
    bare ``task_id`` would raise ``KeyError``; per the original report,
    ``cleanup()`` defensively swallows that and skips ``hook.delete_pod()``,
    leaking the pod until ``active_deadline_seconds`` expires when a user
    marks the task failed.
    """
    dag_run_id = "manual__2026-05-21T00:00:00+00:00"
    # Only the suffixed key is present, so a lookup by bare task_id cannot succeed.
    states_by_task = {"map_group.task_a_2": TaskInstanceState.FAILED}
    mock_get_task_states.return_value = {dag_run_id: states_by_task}

    pod_trigger = KubernetesPodTrigger(
        pod_name=POD_NAME,
        pod_namespace=NAMESPACE,
        base_container_name=BASE_CONTAINER_NAME,
        trigger_start_time=TRIGGER_START_TIME,
        schedule_timeout=STARTUP_TIMEOUT_SECS,
    )
    ti_attrs = {
        "dag_id": "my_dag",
        "task_id": "map_group.task_a",
        "run_id": dag_run_id,
        "map_index": 2,
    }
    pod_trigger.task_instance = MagicMock(**ti_attrs)

    observed_state = await pod_trigger.get_task_state()
    assert observed_state == TaskInstanceState.FAILED

@pytest.mark.skipif(
    not AIRFLOW_V_3_0_PLUS,
    reason="get_task_state uses RuntimeTaskInstance.get_task_states on Airflow 3.0+",
)
@pytest.mark.asyncio
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
async def test_get_task_state_raises_when_mapped_key_missing(self, mock_get_task_states):
    """Check that a missing composite key still surfaces as ``AirflowException``.

    When the response lacks the expected ``{task_id}_{map_index}`` entry the
    lookup failure must keep its wrapped ``AirflowException`` shape, so that
    callers such as ``safe_to_cancel`` see the same behaviour they did
    before the lookup key was corrected.
    """
    from airflow.exceptions import AirflowException

    dag_run_id = "manual__2026-05-21T00:00:00+00:00"
    # The run is known, but only map index 5 is listed; the entry for
    # (``map_group.task_a``, ``2``) is absent -- e.g. the supervisor has not
    # observed that TI yet.
    other_mapped_entry = {"map_group.task_a_5": "running"}
    mock_get_task_states.return_value = {dag_run_id: other_mapped_entry}

    pod_trigger = KubernetesPodTrigger(
        pod_name=POD_NAME,
        pod_namespace=NAMESPACE,
        base_container_name=BASE_CONTAINER_NAME,
        trigger_start_time=TRIGGER_START_TIME,
        schedule_timeout=STARTUP_TIMEOUT_SECS,
    )
    ti_attrs = {
        "dag_id": "my_dag",
        "task_id": "map_group.task_a",
        "run_id": dag_run_id,
        "map_index": 2,
    }
    pod_trigger.task_instance = MagicMock(**ti_attrs)

    with pytest.raises(AirflowException, match="TaskInstance with dag_id"):
        await pod_trigger.get_task_state()

@pytest.mark.skipif(
AIRFLOW_V_3_3_PLUS,
reason="Legacy cleanup path runs only on Airflow < 3.3",
Expand Down
Loading