Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor _manage_executor_state by refreshing TIs in batch #36502

Merged
merged 2 commits into from
Dec 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 36 additions & 6 deletions airflow/jobs/backfill_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import attr
import pendulum
from sqlalchemy import select, update
from sqlalchemy import select, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient
from tabulate import tabulate
Expand Down Expand Up @@ -264,16 +264,46 @@ def _manage_executor_state(
:return: An iterable of expanded TaskInstance per MappedTask
"""
executor = self.job.executor
# list of tuples (dag_id, task_id, execution_date, map_index) of running tasks in executor
buffered_events = list(executor.get_event_buffer().items())
if session.get_bind().dialect.name == "mssql":
# SQL Server doesn't support multiple column subqueries
potiuk marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Remove this once we drop support for SQL Server (#35868)
need_refresh = True
running_dict = {(ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in running.values()}
else:
running_tis_ids = [
(key.dag_id, key.task_id, key.run_id, key.map_index)
for key, _ in buffered_events
if key in running
]
# list of TaskInstance of running tasks in executor (refreshed from db in batch)
refreshed_running_tis = session.scalars(
select(TaskInstance).where(
tuple_(
TaskInstance.dag_id,
TaskInstance.task_id,
TaskInstance.run_id,
TaskInstance.map_index,
).in_(running_tis_ids)
)
).all()
# dict of refreshed TaskInstance by key to easily find them
running_dict = {
(ti.dag_id, ti.task_id, ti.run_id, ti.map_index): ti for ti in refreshed_running_tis
}
need_refresh = False

# TODO: query all instead of refresh from db
for key, value in list(executor.get_event_buffer().items()):
for key, value in buffered_events:
state, info = value
if key not in running:
ti_key = (key.dag_id, key.task_id, key.run_id, key.map_index)
if ti_key not in running_dict:
self.log.warning("%s state %s not in running=%s", key, state, running.values())
continue

ti = running[key]
ti.refresh_from_db()
ti = running_dict[ti_key]
if need_refresh:
ti.refresh_from_db(session=session)

self.log.debug("Executor state: %s task %s", state, ti)

Expand Down