Skip to content

Commit

Permalink
Improve clear_not_launched_queued_tasks call duration (#34985)
Browse files Browse the repository at this point in the history
* Improve clear_not_launched_queued_tasks call duration

* Apply suggestions from code review

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

---------

Co-authored-by: gopal <gopal_dirisala@apple.com>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: Hussein Awala <hussein@awala.fr>
Co-authored-by: Elad Kalif <45845474+eladkal@users.noreply.github.com>
  • Loading branch information
5 people committed Nov 1, 2023
1 parent f84c458 commit 3724a02
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 46 deletions.
47 changes: 36 additions & 11 deletions airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,36 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> Non
if time.time() - timestamp > allowed_age:
del self.last_handled[key]

if not queued_tis:
return

# airflow worker label selector batch call
kwargs = {"label_selector": f"airflow-worker={self._make_safe_label_value(str(self.job_id))}"}
if self.kube_config.kube_client_request_args:
kwargs.update(self.kube_config.kube_client_request_args)
pod_list = self._list_pods(kwargs)

# create a set against pod query label fields
label_search_set = set()
for pod in pod_list:
dag_id = pod.metadata.labels.get("dag_id", None)
task_id = pod.metadata.labels.get("task_id", None)
airflow_worker = pod.metadata.labels.get("airflow-worker", None)
map_index = pod.metadata.labels.get("map_index", None)
run_id = pod.metadata.labels.get("run_id", None)
execution_date = pod.metadata.labels.get("execution_date", None)
if dag_id is None or task_id is None or airflow_worker is None:
continue
label_search_base_str = f"dag_id={dag_id},task_id={task_id},airflow-worker={airflow_worker}"
if map_index is not None:
label_search_base_str += f",map_index={map_index}"
if run_id is not None:
label_search_str = f"{label_search_base_str},run_id={run_id}"
label_search_set.add(label_search_str)
if execution_date is not None:
label_search_str = f"{label_search_base_str},execution_date={execution_date}"
label_search_set.add(label_search_str)

for ti in queued_tis:
self.log.debug("Checking task instance %s", ti)

Expand All @@ -240,21 +270,16 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> Non
if ti.map_index >= 0:
# Old tasks _couldn't_ be mapped, so we don't have to worry about compat
base_label_selector += f",map_index={ti.map_index}"
kwargs = {"label_selector": base_label_selector}
if self.kube_config.kube_client_request_args:
kwargs.update(**self.kube_config.kube_client_request_args)

# Try run_id first
kwargs["label_selector"] += ",run_id=" + self._make_safe_label_value(ti.run_id)
pod_list = self._list_pods(kwargs)
if pod_list:
label_search_str = f"{base_label_selector},run_id={self._make_safe_label_value(ti.run_id)}"
if label_search_str in label_search_set:
continue
# Fallback to old style of using execution_date
kwargs[
"label_selector"
] = f"{base_label_selector},execution_date={self._make_safe_label_value(ti.execution_date)}"
pod_list = self._list_pods(kwargs)
if pod_list:
label_search_str = (
f"{base_label_selector},execution_date={self._make_safe_label_value(ti.execution_date)}"
)
if label_search_str in label_search_set:
continue
self.log.info("TaskInstance: %s found in queued state but was not launched, rescheduling", ti)
session.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,17 +1022,9 @@ def test_clear_not_launched_queued_tasks_not_launched(self, dag_maker, create_du

ti.refresh_from_db()
assert ti.state == State.SCHEDULED
assert mock_kube_client.list_namespaced_pod.call_count == 2
mock_kube_client.list_namespaced_pod.assert_any_call(
namespace="default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test"
)
# also check that we fall back to execution_date if we didn't find the pod with run_id
execution_date_label = pod_generator.datetime_to_label_safe_datestring(ti.execution_date)
assert mock_kube_client.list_namespaced_pod.call_count == 1
mock_kube_client.list_namespaced_pod.assert_called_with(
namespace="default",
label_selector=(
f"dag_id=test_clear,task_id=task1,airflow-worker=1,execution_date={execution_date_label}"
),
namespace="default", label_selector="airflow-worker=1"
)

@pytest.mark.db_test
Expand All @@ -1049,7 +1041,22 @@ def test_clear_not_launched_queued_tasks_launched(
):
"""Leave the state alone if a pod already exists"""
mock_kube_client = mock.MagicMock()
mock_kube_client.list_namespaced_pod.return_value = k8s.V1PodList(items=["something"])
mock_kube_client.list_namespaced_pod.return_value = k8s.V1PodList(
items=[
k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
labels={
"role": "airflow-worker",
"dag_id": "test_clear",
"task_id": "task1",
"airflow-worker": 1,
"run_id": "test",
},
),
status=k8s.V1PodStatus(phase="Pending"),
)
]
)

create_dummy_dag(dag_id="test_clear", task_id="task1", with_dagrun_type=None)
dag_run = dag_maker.create_dagrun()
Expand All @@ -1069,18 +1076,31 @@ def test_clear_not_launched_queued_tasks_launched(
ti.refresh_from_db()
assert ti.state == State.QUEUED
mock_kube_client.list_namespaced_pod.assert_called_once_with(
namespace="default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test"
namespace="default", label_selector="airflow-worker=1"
)

@pytest.mark.db_test
def test_clear_not_launched_queued_tasks_mapped_task(self, dag_maker, session):
"""One mapped task has a launched pod - other does not."""

def list_namespaced_pod(*args, **kwargs):
if "map_index=0" in kwargs["label_selector"]:
return k8s.V1PodList(items=["something"])
else:
return k8s.V1PodList(items=[])
return k8s.V1PodList(
items=[
k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
labels={
"role": "airflow-worker",
"dag_id": "test_clear",
"task_id": "bash",
"airflow-worker": 1,
"map_index": 0,
"run_id": "test",
},
),
status=k8s.V1PodStatus(phase="Pending"),
)
]
)

mock_kube_client = mock.MagicMock()
mock_kube_client.list_namespaced_pod.side_effect = list_namespaced_pod
Expand Down Expand Up @@ -1109,25 +1129,10 @@ def list_namespaced_pod(*args, **kwargs):
assert ti0.state == State.QUEUED
assert ti1.state == State.SCHEDULED

assert mock_kube_client.list_namespaced_pod.call_count == 3
execution_date_label = pod_generator.datetime_to_label_safe_datestring(dag_run.execution_date)
mock_kube_client.list_namespaced_pod.assert_has_calls(
[
mock.call(
namespace="default",
label_selector="dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=0,run_id=test",
),
mock.call(
namespace="default",
label_selector="dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=1,run_id=test",
),
mock.call(
namespace="default",
label_selector=f"dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=1,"
f"execution_date={execution_date_label}",
),
],
any_order=True,
assert mock_kube_client.list_namespaced_pod.call_count == 1
mock_kube_client.list_namespaced_pod.assert_called_with(
namespace="default",
label_selector="airflow-worker=1",
)

@pytest.mark.db_test
Expand Down

0 comments on commit 3724a02

Please sign in to comment.