Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 73 additions & 61 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __init__(
self._parallelism = conf.getint("core", "parallelism")
self._multi_team = conf.getboolean("core", "multi_team")
self._max_partition_dag_runs_per_loop = MAX_PARTITION_DAG_RUNS_PER_LOOP
self._dag_id_to_team_name: dict[str, str | None] = {}

self.executors: list[BaseExecutor] = executors if executors else ExecutorLoader.init_executors()
self.executor: BaseExecutor = self.executors[0]
Expand Down Expand Up @@ -383,9 +384,11 @@ def _get_team_names_for_dag_ids(
self, dag_ids: Collection[str], session: Session
) -> dict[str, str | None]:
"""
Batch query to resolve team names for multiple DAG IDs using the DAG > Bundle > Team relationship chain.
Resolve team names for DAG IDs via the DAG Bundle Team relationship.

DAG IDs > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team
Results are cached for the lifetime of the scheduler process since this is called
on every heartbeat (for metrics tagging) with largely overlapping dag_id sets, and
the underlying team assignment only changes on bundle redeployment.

:param dag_ids: Collection of DAG IDs to resolve team names for
:param session: Database session for queries
Expand All @@ -394,36 +397,35 @@ def _get_team_names_for_dag_ids(
if not dag_ids:
return {}

try:
# Query all team names for the given DAG IDs in a single query
query_results = session.execute(
select(DagModel.dag_id, Team.name)
.join(DagBundleModel.teams) # Join Team to DagBundleModel via association table
.join(
DagModel, DagModel.bundle_name == DagBundleModel.name
) # Join DagBundleModel to DagModel
.where(DagModel.dag_id.in_(dag_ids))
).all()

# Create mapping from results
dag_id_to_team_name = {dag_id: team_name for dag_id, team_name in query_results}

# Ensure all requested dag_ids are in the result (with None for those not found)
result = {dag_id: dag_id_to_team_name.get(dag_id) for dag_id in dag_ids}
missing = [dag_id for dag_id in dag_ids if dag_id not in self._dag_id_to_team_name]
if missing:
try:
# Query all team names for the given DAG IDs in a single query
query_results = session.execute(
select(DagModel.dag_id, Team.name)
.join(DagBundleModel.teams) # Join Team to DagBundleModel via association table
.join(
DagModel, DagModel.bundle_name == DagBundleModel.name
) # Join DagBundleModel to DagModel
.where(DagModel.dag_id.in_(missing))
).all()

# Create mapping from results
queried = {dag_id: team_name for dag_id, team_name in query_results}

# Cache all results, including None for dag_ids with no team
for dag_id in missing:
self._dag_id_to_team_name[dag_id] = queried.get(dag_id)
self.log.debug("Cached team names for %d new dag_ids", len(missing))

self.log.debug(
"Resolved team names for %d DAGs: %s",
len([team for team in result.values() if team is not None]),
{dag_id: team for dag_id, team in result.items()},
)
except Exception:
# Log the error, explicitly don't fail the scheduling loop
self.log.exception("Failed to resolve team names for DAG IDs: %s", missing)
# Return dict with all None values to ensure graceful degradation
return {}

return result

except Exception:
# Log the error, explicitly don't fail the scheduling loop
self.log.exception("Failed to resolve team names for DAG IDs: %s", list(dag_ids))
# Return dict with all None values to ensure graceful degradation
return {}
# Ensure all requested dag_ids are in the result (with None for those not found)
return {dag_id: self._dag_id_to_team_name.get(dag_id) for dag_id in dag_ids}

def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) -> str | None:
"""
Expand Down Expand Up @@ -714,6 +716,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
len(unique_dag_ids),
list(unique_dag_ids),
)
for ti in task_instances_to_examine:
if team := dag_id_to_team_name.get(ti.dag_id):
ti._team_name = team

executor_slots_available: dict[ExecutorName, int] = {}
# First get a mapping of executor names to slots they have available
Expand Down Expand Up @@ -923,12 +928,16 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
loop_count,
)

starving_pool_team_mapping = (
Pool.get_name_to_team_name_mapping(list(pool_num_starving_tasks.keys()), session=session)
if self._multi_team and pool_num_starving_tasks
else {}
)
for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
stats.gauge(
"pool.starving_tasks",
num_starving_tasks,
tags={"pool_name": pool_name},
)
starving_tags: dict[str, str] = {"pool_name": normalize_pool_name_for_stats(pool_name)}
if team := starving_pool_team_mapping.get(pool_name):
starving_tags["team_name"] = team
stats.gauge("pool.starving_tasks", num_starving_tasks, tags=starving_tags)

stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
stats.gauge("scheduler.tasks.executable", len(executable_tis))
Expand Down Expand Up @@ -1602,6 +1611,12 @@ def _update_dag_run_state_for_paused_dags(self, *, session: Session = NEW_SESSIO
.group_by(DagRun)
)
)
if self._multi_team and paused_runs:
paused_dag_ids = {dr.dag_id for dr in paused_runs}
paused_team_mapping = self._get_team_names_for_dag_ids(paused_dag_ids, session)
for dr in paused_runs:
if team := paused_team_mapping.get(dr.dag_id):
dr._team_name = team
for dag_run in paused_runs:
dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, session=session)
if dag is not None:
Expand Down Expand Up @@ -1839,6 +1854,13 @@ def _do_scheduling(self, session: Session) -> int:
# examining, rather than making one query per DagRun
dag_runs = DagRun.get_running_dag_runs_to_examine(session=session)

if self._multi_team and dag_runs:
unique_dag_ids = {dr.dag_id for dr in dag_runs}
dr_team_mapping = self._get_team_names_for_dag_ids(unique_dag_ids, session)
for dr in dag_runs:
if team := dr_team_mapping.get(dr.dag_id):
dr._team_name = team

callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

# Send the callbacks after we commit to ensure the context is up to date when it gets run
Expand Down Expand Up @@ -3010,33 +3032,20 @@ def _emit_pool_metrics(self, *, session: Session = NEW_SESSION) -> None:
from airflow.models.pool import Pool

pools = Pool.slots_stats(session=session)
pool_team_mapping = (
Pool.get_name_to_team_name_mapping(list(pools.keys()), session=session)
if self._multi_team
else {}
)
for pool_name, slot_stats in pools.items():
normalized_pool_name = normalize_pool_name_for_stats(pool_name)
stats.gauge(
"pool.open_slots",
slot_stats["open"],
tags={"pool_name": normalized_pool_name},
)
stats.gauge(
"pool.queued_slots",
slot_stats["queued"],
tags={"pool_name": normalized_pool_name},
)
stats.gauge(
"pool.running_slots",
slot_stats["running"],
tags={"pool_name": normalized_pool_name},
)
stats.gauge(
"pool.deferred_slots",
slot_stats["deferred"],
tags={"pool_name": normalized_pool_name},
)
stats.gauge(
"pool.scheduled_slots",
slot_stats["scheduled"],
tags={"pool_name": normalized_pool_name},
)
metric_tags: dict[str, str] = {"pool_name": normalize_pool_name_for_stats(pool_name)}
if team := pool_team_mapping.get(pool_name):
metric_tags["team_name"] = team
stats.gauge("pool.open_slots", slot_stats["open"], tags=metric_tags)
stats.gauge("pool.queued_slots", slot_stats["queued"], tags=metric_tags)
stats.gauge("pool.running_slots", slot_stats["running"], tags=metric_tags)
stats.gauge("pool.deferred_slots", slot_stats["deferred"], tags=metric_tags)
stats.gauge("pool.scheduled_slots", slot_stats["scheduled"], tags=metric_tags)

@provide_session
def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION) -> int:
Expand Down Expand Up @@ -3228,6 +3237,9 @@ def _purge_task_instances_without_heartbeats(
if self._multi_team:
unique_dag_ids = {ti.dag_id for ti in task_instances_without_heartbeats}
dag_id_to_team_name = self._get_team_names_for_dag_ids(unique_dag_ids, session)
for ti in task_instances_without_heartbeats:
if team := dag_id_to_team_name.get(ti.dag_id):
ti._team_name = team
else:
dag_id_to_team_name = {}

Expand Down
6 changes: 5 additions & 1 deletion airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,11 @@ def check_version_id_exists_in_dr(self, dag_version_id: UUID, *, session: Sessio

@property
def stats_tags(self) -> dict[str, str]:
    """
    Return the metric tags for this DAG run.

    The base tags are ``dag_id`` and ``run_type``, passed through ``prune_dict``.
    A ``team_name`` tag is added only when a truthy ``_team_name`` attribute has
    been attached to this instance; the attribute is optional and is not
    required to exist.
    """
    tags = prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
    # ``_team_name`` may never have been set on this object, so read it defensively.
    team_name = getattr(self, "_team_name", None)
    if team_name:
        tags["team_name"] = team_name
    return tags

def get_state(self):
return self._state
Expand Down
6 changes: 5 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,11 @@ def __hash__(self):
@property
def stats_tags(self) -> dict[str, str]:
    """
    Return the metric tags for this task instance.

    The base tags are ``dag_id`` and ``task_id``, passed through ``prune_dict``.
    A ``team_name`` tag is added only when a truthy ``_team_name`` attribute has
    been attached to this instance; the attribute is optional and is not
    required to exist.
    """
    tags = prune_dict({"dag_id": self.dag_id, "task_id": self.task_id})
    # ``_team_name`` may never have been set on this object, so read it defensively.
    team_name = getattr(self, "_team_name", None)
    if team_name:
        tags["team_name"] = team_name
    return tags

@staticmethod
def insert_mapping(
Expand Down
67 changes: 67 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,6 +2825,38 @@ def test_emit_pool_starving_tasks_metrics(self, mock_get_backend, dag_maker):
session.rollback()
session.close()

@pytest.mark.parametrize(
    ("multi_team", "expected_tags"),
    [
        pytest.param("true", {"pool_name": "team_pool", "team_name": "team_a"}, id="with_team"),
        pytest.param("false", {"pool_name": "team_pool"}, id="without_team"),
    ],
)
@mock.patch("airflow._shared.observability.metrics.stats._get_backend")
def test_emit_pool_metrics_team_name(self, mock_get_backend, multi_team, expected_tags, session):
    """Pool metrics include team_name only when multi_team is enabled."""
    mock_stats = mock.MagicMock(spec=StatsLogger)
    mock_get_backend.return_value = mock_stats

    clear_db_teams()

    team = Team(name="team_a")
    session.add(team)
    session.flush()

    pool = Pool(pool="team_pool", slots=5, include_deferred=False, team_name="team_a")
    session.add(pool)
    session.flush()

    with conf_vars({("core", "multi_team"): multi_team}):
        scheduler_job = Job()
        self.job_runner = SchedulerJobRunner(job=scheduler_job)
        self.job_runner._emit_pool_metrics(session=session)

    # Check every per-pool gauge, not just a subset, so a gauge that loses
    # (or wrongly gains) the team_name tag is caught.
    for metric_name in (
        "pool.open_slots",
        "pool.queued_slots",
        "pool.running_slots",
        "pool.deferred_slots",
        "pool.scheduled_slots",
    ):
        mock_stats.gauge.assert_any_call(metric_name, mock.ANY, tags=expected_tags)

def test_enqueue_task_instances_with_queued_state(self, dag_maker, session):
dag_id = "SchedulerJobTest.test_enqueue_task_instances_with_queued_state"
task_id_1 = "dummy"
Expand Down Expand Up @@ -9171,6 +9203,41 @@ def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_e
assert result1 == self.job_runner.executor # Default for no explicit executor
assert result2 == mock_executors[1] # Matched by executor name

@conf_vars({("core", "multi_team"): "true"})
def test_multi_team_sets_team_name_on_task_instances(self, dag_maker, mock_executors, session):
    """A team name resolved for a DAG is attached to its TaskInstance and shows up in stats_tags."""
    clear_db_teams()
    clear_db_dag_bundles()

    # Team -> bundle association that the DAG below is attached to
    team_a = Team(name="team_a")
    session.add(team_a)
    session.flush()

    bundle_a = DagBundleModel(name="bundle_a")
    bundle_a.teams.append(team_a)
    session.add(bundle_a)
    session.flush()

    with dag_maker(dag_id="dag_a", bundle_name="bundle_a", session=session):
        EmptyOperator(task_id="task_a")

    dag_run = dag_maker.create_dagrun()
    task_instance = dag_run.get_task_instance("task_a", session=session)
    task_instance.state = State.SCHEDULED
    session.flush()

    self.job_runner = SchedulerJobRunner(job=Job())
    self.job_runner._multi_team = True

    # Mirror the tagging step performed in _executable_task_instances_to_queued
    team_by_dag_id = self.job_runner._get_team_names_for_dag_ids(["dag_a"], session)
    resolved_team = team_by_dag_id.get(task_instance.dag_id)
    if resolved_team:
        task_instance._team_name = resolved_team

    assert task_instance._team_name == "team_a"
    assert task_instance.stats_tags == {"dag_id": "dag_a", "task_id": "task_a", "team_name": "team_a"}


@pytest.mark.need_serialized_dag
def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
Expand Down
30 changes: 30 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,3 +3851,33 @@ def test_context_carrier_includes_detail_level_from_conf(self, dag_maker):

span = trace.get_current_span(ctx)
assert get_task_span_detail_level(span) == 2


class TestDagRunStatsTagsTeamName:
    """Checks how DagRun.stats_tags reacts to the optional ``_team_name`` attribute."""

    @staticmethod
    def _make_dag_run(dag_maker):
        """Build a DagRun for a one-task DAG called ``test_dag``."""
        with dag_maker("test_dag"):
            EmptyOperator(task_id="t1")
        return dag_maker.create_dagrun()

    def test_stats_tags_without_team_name(self, dag_maker):
        """No team_name tag is emitted when _team_name was never assigned."""
        dag_run = self._make_dag_run(dag_maker)
        stats_tags = dag_run.stats_tags
        assert "team_name" not in stats_tags
        assert stats_tags == {"dag_id": "test_dag", "run_type": "manual"}

    def test_stats_tags_with_team_name(self, dag_maker):
        """The team_name tag mirrors _team_name once it has been assigned."""
        dag_run = self._make_dag_run(dag_maker)
        dag_run._team_name = "my_team"
        stats_tags = dag_run.stats_tags
        assert stats_tags["team_name"] == "my_team"
        assert stats_tags == {"dag_id": "test_dag", "run_type": "manual", "team_name": "my_team"}

    def test_stats_tags_with_none_team_name(self, dag_maker):
        """An explicit None _team_name yields no team_name tag."""
        dag_run = self._make_dag_run(dag_maker)
        dag_run._team_name = None
        stats_tags = dag_run.stats_tags
        assert "team_name" not in stats_tags
33 changes: 33 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4068,3 +4068,36 @@ def test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker, sessi

assert "<deferred>" in result
assert "[queued]" not in result


class TestTaskInstanceStatsTagsTeamName:
    """Checks how TaskInstance.stats_tags reacts to the optional ``_team_name`` attribute."""

    @staticmethod
    def _make_task_instance(dag_maker, session):
        """Build the ``my_task`` TaskInstance of a one-task DAG called ``test_dag``."""
        with dag_maker("test_dag"):
            EmptyOperator(task_id="my_task")
        dag_run = dag_maker.create_dagrun()
        return dag_run.get_task_instance("my_task", session=session)

    def test_stats_tags_without_team_name(self, dag_maker, session):
        """No team_name tag is emitted when _team_name was never assigned."""
        task_instance = self._make_task_instance(dag_maker, session)
        stats_tags = task_instance.stats_tags
        assert "team_name" not in stats_tags
        assert stats_tags == {"dag_id": "test_dag", "task_id": "my_task"}

    def test_stats_tags_with_team_name(self, dag_maker, session):
        """The team_name tag mirrors _team_name once it has been assigned."""
        task_instance = self._make_task_instance(dag_maker, session)
        task_instance._team_name = "my_team"
        stats_tags = task_instance.stats_tags
        assert stats_tags["team_name"] == "my_team"
        assert stats_tags == {"dag_id": "test_dag", "task_id": "my_task", "team_name": "my_team"}

    def test_stats_tags_with_none_team_name(self, dag_maker, session):
        """An explicit None _team_name yields no team_name tag."""
        task_instance = self._make_task_instance(dag_maker, session)
        task_instance._team_name = None
        stats_tags = task_instance.stats_tags
        assert "team_name" not in stats_tags
Loading