Skip to content

Commit

Permalink
Fix external_executor_id being overwritten (apache#37784)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi authored and utkarsharma2 committed Apr 22, 2024
1 parent 8ebf024 commit e45c12f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
5 changes: 4 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,7 +2215,10 @@ def _check_and_change_state_before_execution(

ti.state = TaskInstanceState.RUNNING
ti.emit_state_change_metric(TaskInstanceState.RUNNING)
ti.external_executor_id = external_executor_id

if external_executor_id:
ti.external_executor_id = external_executor_id

ti.end_date = None
if not test_mode:
session.merge(ti).task = task
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,7 @@ def maker(
run_id=None,
run_type=None,
data_interval=None,
external_executor_id=None,
map_index=-1,
**kwargs,
) -> TaskInstance:
Expand All @@ -936,6 +937,7 @@ def maker(
(ti,) = dagrun.task_instances
ti.task = task
ti.state = state
ti.external_executor_id = external_executor_id
ti.map_index = map_index

dag_maker.session.flush()
Expand Down
47 changes: 46 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,11 @@ def post_execute(self, context, result=None):
ti.run()

def test_check_and_change_state_before_execution(self, create_task_instance):
ti = create_task_instance(dag_id="test_check_and_change_state_before_execution")
expected_external_executor_id = "banana"
ti = create_task_instance(
dag_id="test_check_and_change_state_before_execution",
external_executor_id=expected_external_executor_id,
)
SerializedDagModel.write_dag(ti.task.dag)

serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag
Expand All @@ -1791,6 +1795,46 @@ def test_check_and_change_state_before_execution(self, create_task_instance):
assert ti_from_deserialized_task._try_number == 0
assert ti_from_deserialized_task.check_and_change_state_before_execution()
# State should be running, and try_number column should be incremented
assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id
assert ti_from_deserialized_task.state == State.RUNNING
assert ti_from_deserialized_task._try_number == 1

def test_check_and_change_state_before_execution_provided_id_overrides(self, create_task_instance):
expected_external_executor_id = "banana"
ti = create_task_instance(
dag_id="test_check_and_change_state_before_execution",
external_executor_id="apple",
)
assert ti.external_executor_id == "apple"
SerializedDagModel.write_dag(ti.task.dag)

serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag
ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id)

assert ti_from_deserialized_task._try_number == 0
assert ti_from_deserialized_task.check_and_change_state_before_execution(
external_executor_id=expected_external_executor_id
)
# State should be running, and try_number column should be incremented
assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id
assert ti_from_deserialized_task.state == State.RUNNING
assert ti_from_deserialized_task._try_number == 1

def test_check_and_change_state_before_execution_with_exec_id(self, create_task_instance):
expected_external_executor_id = "minions"
ti = create_task_instance(dag_id="test_check_and_change_state_before_execution")
assert ti.external_executor_id is None
SerializedDagModel.write_dag(ti.task.dag)

serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag
ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id)

assert ti_from_deserialized_task._try_number == 0
assert ti_from_deserialized_task.check_and_change_state_before_execution(
external_executor_id=expected_external_executor_id
)
# State should be running, and try_number column should be incremented
assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id
assert ti_from_deserialized_task.state == State.RUNNING
assert ti_from_deserialized_task._try_number == 1

Expand All @@ -1817,6 +1861,7 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running(sel

assert not ti_from_deserialized_task.check_and_change_state_before_execution()
assert ti_from_deserialized_task.state == State.RUNNING
assert ti_from_deserialized_task.external_executor_id is None

def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state(
self, create_task_instance
Expand Down

0 comments on commit e45c12f

Please sign in to comment.