diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index ba39e44030cb1..cd4dbe82888f7 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -333,6 +333,7 @@ def __init__( self._parallelism = conf.getint("core", "parallelism") self._multi_team = conf.getboolean("core", "multi_team") self._max_partition_dag_runs_per_loop = MAX_PARTITION_DAG_RUNS_PER_LOOP + self._dag_id_to_team_name: dict[str, str | None] = {} self.executors: list[BaseExecutor] = executors if executors else ExecutorLoader.init_executors() self.executor: BaseExecutor = self.executors[0] @@ -383,9 +384,11 @@ def _get_team_names_for_dag_ids( self, dag_ids: Collection[str], session: Session ) -> dict[str, str | None]: """ - Batch query to resolve team names for multiple DAG IDs using the DAG > Bundle > Team relationship chain. + Resolve team names for DAG IDs via the DAG → Bundle → Team relationship. - DAG IDs > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team + Results are cached for the lifetime of the scheduler process since this is called + on every heartbeat (for metrics tagging) with largely overlapping dag_id sets, and + the underlying team assignment only changes on bundle redeployment. 
:param dag_ids: Collection of DAG IDs to resolve team names for :param session: Database session for queries @@ -394,36 +397,35 @@ def _get_team_names_for_dag_ids( if not dag_ids: return {} - try: - # Query all team names for the given DAG IDs in a single query - query_results = session.execute( - select(DagModel.dag_id, Team.name) - .join(DagBundleModel.teams) # Join Team to DagBundleModel via association table - .join( - DagModel, DagModel.bundle_name == DagBundleModel.name - ) # Join DagBundleModel to DagModel - .where(DagModel.dag_id.in_(dag_ids)) - ).all() - - # Create mapping from results - dag_id_to_team_name = {dag_id: team_name for dag_id, team_name in query_results} - - # Ensure all requested dag_ids are in the result (with None for those not found) - result = {dag_id: dag_id_to_team_name.get(dag_id) for dag_id in dag_ids} + missing = [dag_id for dag_id in dag_ids if dag_id not in self._dag_id_to_team_name] + if missing: + try: + # Query all team names for the given DAG IDs in a single query + query_results = session.execute( + select(DagModel.dag_id, Team.name) + .join(DagBundleModel.teams) # Join Team to DagBundleModel via association table + .join( + DagModel, DagModel.bundle_name == DagBundleModel.name + ) # Join DagBundleModel to DagModel + .where(DagModel.dag_id.in_(missing)) + ).all() + + # Create mapping from results + queried = {dag_id: team_name for dag_id, team_name in query_results} + + # Cache all results, including None for dag_ids with no team + for dag_id in missing: + self._dag_id_to_team_name[dag_id] = queried.get(dag_id) + self.log.debug("Cached team names for %d new dag_ids", len(missing)) - self.log.debug( - "Resolved team names for %d DAGs: %s", - len([team for team in result.values() if team is not None]), - {dag_id: team for dag_id, team in result.items()}, - ) + except Exception: + # Log the error, explicitly don't fail the scheduling loop + self.log.exception("Failed to resolve team names for DAG IDs: %s", missing) + # Return 
an empty dict (no team names resolved) to ensure graceful degradation + return {} - return result - - except Exception: - # Log the error, explicitly don't fail the scheduling loop - self.log.exception("Failed to resolve team names for DAG IDs: %s", list(dag_ids)) - # Return dict with all None values to ensure graceful degradation - return {} + # Ensure all requested dag_ids are in the result (with None for those not found) + return {dag_id: self._dag_id_to_team_name.get(dag_id) for dag_id in dag_ids} def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) -> str | None: """ @@ -714,6 +716,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - len(unique_dag_ids), list(unique_dag_ids), ) + for ti in task_instances_to_examine: + if team := dag_id_to_team_name.get(ti.dag_id): + ti._team_name = team executor_slots_available: dict[ExecutorName, int] = {} # First get a mapping of executor names to slots they have available @@ -923,12 +928,16 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - loop_count, ) + starving_pool_team_mapping = ( + Pool.get_name_to_team_name_mapping(list(pool_num_starving_tasks.keys()), session=session) + if self._multi_team and pool_num_starving_tasks + else {} + ) for pool_name, num_starving_tasks in pool_num_starving_tasks.items(): - stats.gauge( - "pool.starving_tasks", - num_starving_tasks, - tags={"pool_name": pool_name}, - ) + starving_tags: dict[str, str] = {"pool_name": normalize_pool_name_for_stats(pool_name)} + if team := starving_pool_team_mapping.get(pool_name): + starving_tags["team_name"] = team + stats.gauge("pool.starving_tasks", num_starving_tasks, tags=starving_tags) stats.gauge("scheduler.tasks.starving", num_starving_tasks_total) stats.gauge("scheduler.tasks.executable", len(executable_tis)) @@ -1602,6 +1611,12 @@ def _update_dag_run_state_for_paused_dags(self, *, session: Session = NEW_SESSIO .group_by(DagRun) ) ) + if self._multi_team and 
paused_runs: + paused_dag_ids = {dr.dag_id for dr in paused_runs} + paused_team_mapping = self._get_team_names_for_dag_ids(paused_dag_ids, session) + for dr in paused_runs: + if team := paused_team_mapping.get(dr.dag_id): + dr._team_name = team for dag_run in paused_runs: dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, session=session) if dag is not None: @@ -1839,6 +1854,13 @@ def _do_scheduling(self, session: Session) -> int: # examining, rather than making one query per DagRun dag_runs = DagRun.get_running_dag_runs_to_examine(session=session) + if self._multi_team and dag_runs: + unique_dag_ids = {dr.dag_id for dr in dag_runs} + dr_team_mapping = self._get_team_names_for_dag_ids(unique_dag_ids, session) + for dr in dag_runs: + if team := dr_team_mapping.get(dr.dag_id): + dr._team_name = team + callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session) # Send the callbacks after we commit to ensure the context is up to date when it gets run @@ -3010,33 +3032,20 @@ def _emit_pool_metrics(self, *, session: Session = NEW_SESSION) -> None: from airflow.models.pool import Pool pools = Pool.slots_stats(session=session) + pool_team_mapping = ( + Pool.get_name_to_team_name_mapping(list(pools.keys()), session=session) + if self._multi_team + else {} + ) for pool_name, slot_stats in pools.items(): - normalized_pool_name = normalize_pool_name_for_stats(pool_name) - stats.gauge( - "pool.open_slots", - slot_stats["open"], - tags={"pool_name": normalized_pool_name}, - ) - stats.gauge( - "pool.queued_slots", - slot_stats["queued"], - tags={"pool_name": normalized_pool_name}, - ) - stats.gauge( - "pool.running_slots", - slot_stats["running"], - tags={"pool_name": normalized_pool_name}, - ) - stats.gauge( - "pool.deferred_slots", - slot_stats["deferred"], - tags={"pool_name": normalized_pool_name}, - ) - stats.gauge( - "pool.scheduled_slots", - slot_stats["scheduled"], - tags={"pool_name": normalized_pool_name}, - ) + metric_tags: dict[str, str] = 
{"pool_name": normalize_pool_name_for_stats(pool_name)} + if team := pool_team_mapping.get(pool_name): + metric_tags["team_name"] = team + stats.gauge("pool.open_slots", slot_stats["open"], tags=metric_tags) + stats.gauge("pool.queued_slots", slot_stats["queued"], tags=metric_tags) + stats.gauge("pool.running_slots", slot_stats["running"], tags=metric_tags) + stats.gauge("pool.deferred_slots", slot_stats["deferred"], tags=metric_tags) + stats.gauge("pool.scheduled_slots", slot_stats["scheduled"], tags=metric_tags) @provide_session def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION) -> int: @@ -3228,6 +3237,9 @@ def _purge_task_instances_without_heartbeats( if self._multi_team: unique_dag_ids = {ti.dag_id for ti in task_instances_without_heartbeats} dag_id_to_team_name = self._get_team_names_for_dag_ids(unique_dag_ids, session) + for ti in task_instances_without_heartbeats: + if team := dag_id_to_team_name.get(ti.dag_id): + ti._team_name = team else: dag_id_to_team_name = {} diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 564e11a9522d7..e3c78356d6e63 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -500,7 +500,11 @@ def check_version_id_exists_in_dr(self, dag_version_id: UUID, *, session: Sessio @property def stats_tags(self) -> dict[str, str]: - return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type}) + tags = prune_dict({"dag_id": self.dag_id, "run_type": self.run_type}) + team_name = getattr(self, "_team_name", None) + if team_name: + tags["team_name"] = team_name + return tags def get_state(self): return self._state diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 0d0183665fa46..271a6c06661d2 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -731,7 +731,11 @@ def __hash__(self): 
@property def stats_tags(self) -> dict[str, str]: """Returns task instance tags.""" - return prune_dict({"dag_id": self.dag_id, "task_id": self.task_id}) + tags = prune_dict({"dag_id": self.dag_id, "task_id": self.task_id}) + team_name = getattr(self, "_team_name", None) + if team_name: + tags["team_name"] = team_name + return tags @staticmethod def insert_mapping( diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 258ce3db4cff6..18cc1128da523 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -2825,6 +2825,38 @@ def test_emit_pool_starving_tasks_metrics(self, mock_get_backend, dag_maker): session.rollback() session.close() + @pytest.mark.parametrize( + ("multi_team", "expected_tags"), + [ + pytest.param("true", {"pool_name": "team_pool", "team_name": "team_a"}, id="with_team"), + pytest.param("false", {"pool_name": "team_pool"}, id="without_team"), + ], + ) + @mock.patch("airflow._shared.observability.metrics.stats._get_backend") + def test_emit_pool_metrics_team_name(self, mock_get_backend, multi_team, expected_tags, session): + """Pool metrics include team_name only when multi_team is enabled.""" + mock_stats = mock.MagicMock(spec=StatsLogger) + mock_get_backend.return_value = mock_stats + + clear_db_teams() + + team = Team(name="team_a") + session.add(team) + session.flush() + + pool = Pool(pool="team_pool", slots=5, include_deferred=False, team_name="team_a") + session.add(pool) + session.flush() + + with conf_vars({("core", "multi_team"): multi_team}): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) + self.job_runner._emit_pool_metrics(session=session) + + mock_stats.gauge.assert_any_call("pool.open_slots", mock.ANY, tags=expected_tags) + mock_stats.gauge.assert_any_call("pool.queued_slots", mock.ANY, tags=expected_tags) + mock_stats.gauge.assert_any_call("pool.running_slots", 
mock.ANY, tags=expected_tags) + def test_enqueue_task_instances_with_queued_state(self, dag_maker, session): dag_id = "SchedulerJobTest.test_enqueue_task_instances_with_queued_state" task_id_1 = "dummy" @@ -9171,6 +9203,41 @@ def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_e assert result1 == self.job_runner.executor # Default for no explicit executor assert result2 == mock_executors[1] # Matched by executor name + @conf_vars({("core", "multi_team"): "true"}) + def test_multi_team_sets_team_name_on_task_instances(self, dag_maker, mock_executors, session): + """Test that _team_name is set on TaskInstance objects during the scheduling loop.""" + clear_db_teams() + clear_db_dag_bundles() + + team = Team(name="team_a") + session.add(team) + session.flush() + + bundle = DagBundleModel(name="bundle_a") + bundle.teams.append(team) + session.add(bundle) + session.flush() + + with dag_maker(dag_id="dag_a", bundle_name="bundle_a", session=session): + EmptyOperator(task_id="task_a") + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("task_a", session=session) + ti.state = State.SCHEDULED + session.flush() + + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) + self.job_runner._multi_team = True + + # Simulate what _executable_task_instances_to_queued does + dag_id_to_team_name = self.job_runner._get_team_names_for_dag_ids(["dag_a"], session) + if team_name := dag_id_to_team_name.get(ti.dag_id): + ti._team_name = team_name + + assert ti._team_name == "team_a" + assert ti.stats_tags == {"dag_id": "dag_a", "task_id": "task_a", "team_name": "team_a"} + @pytest.mark.need_serialized_dag def test_schedule_dag_run_with_upstream_skip(dag_maker, session): diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 1154a19617e3e..b2c6bb2742db5 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -3851,3 
+3851,33 @@ def test_context_carrier_includes_detail_level_from_conf(self, dag_maker): span = trace.get_current_span(ctx) assert get_task_span_detail_level(span) == 2 + + +class TestDagRunStatsTagsTeamName: + def test_stats_tags_without_team_name(self, dag_maker): + """stats_tags should not include team_name when _team_name is not set.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="t1") + dr = dag_maker.create_dagrun() + tags = dr.stats_tags + assert "team_name" not in tags + assert tags == {"dag_id": "test_dag", "run_type": "manual"} + + def test_stats_tags_with_team_name(self, dag_maker): + """stats_tags should include team_name when _team_name is set.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="t1") + dr = dag_maker.create_dagrun() + dr._team_name = "my_team" + tags = dr.stats_tags + assert tags["team_name"] == "my_team" + assert tags == {"dag_id": "test_dag", "run_type": "manual", "team_name": "my_team"} + + def test_stats_tags_with_none_team_name(self, dag_maker): + """stats_tags should not include team_name when _team_name is None.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="t1") + dr = dag_maker.create_dagrun() + dr._team_name = None + tags = dr.stats_tags + assert "team_name" not in tags diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 8e9212d51d6d6..7f2401bc37ad8 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -4068,3 +4068,36 @@ def test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker, sessi assert "" in result assert "[queued]" not in result + + +class TestTaskInstanceStatsTagsTeamName: + def test_stats_tags_without_team_name(self, dag_maker, session): + """stats_tags should not include team_name when _team_name is not set.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="my_task") + dr = dag_maker.create_dagrun() + ti = 
dr.get_task_instance("my_task", session=session) + tags = ti.stats_tags + assert "team_name" not in tags + assert tags == {"dag_id": "test_dag", "task_id": "my_task"} + + def test_stats_tags_with_team_name(self, dag_maker, session): + """stats_tags should include team_name when _team_name is set.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="my_task") + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("my_task", session=session) + ti._team_name = "my_team" + tags = ti.stats_tags + assert tags["team_name"] == "my_team" + assert tags == {"dag_id": "test_dag", "task_id": "my_task", "team_name": "my_team"} + + def test_stats_tags_with_none_team_name(self, dag_maker, session): + """stats_tags should not include team_name when _team_name is None.""" + with dag_maker("test_dag"): + EmptyOperator(task_id="my_task") + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance("my_task", session=session) + ti._team_name = None + tags = ti.stats_tags + assert "team_name" not in tags diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6f48756433bf1..b34421c78b8f8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -249,6 +249,14 @@ class RuntimeTaskInstance(TaskInstance): sentry_integration: str = "" + @property + def stats_tags(self) -> dict[str, str]: + """Metric tags for this task instance, including team_name when available.""" + tags: dict[str, str] = {"dag_id": self.dag_id, "task_id": self.task_id} + if self._ti_context_from_server and self._ti_context_from_server.dag_run.team_name: + tags["team_name"] = self._ti_context_from_server.dag_run.team_name + return tags + def __rich_repr__(self): yield "id", self.id yield "task_id", self.task_id @@ -1408,7 +1416,7 @@ def _on_term(signum, frame): state: TaskInstanceState | None = None error: BaseException | None = None - stats_tags = {"dag_id": 
ti.dag_id, "task_id": ti.task_id} + stats_tags = ti.stats_tags stats.incr("ti.start", tags=stats_tags) try: @@ -1584,7 +1592,7 @@ def _handle_current_task_success( # Record operator and task instance success metrics operator = ti.task.__class__.__name__ - stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} + stats_tags = ti.stats_tags stats.incr("operator_successes", tags={**stats_tags, "operator_name": operator}) stats.incr("ti_successes", tags=stats_tags) @@ -1683,7 +1691,7 @@ def _handle_current_task_failed( # Record operator and task instance failed metrics operator = ti.task.__class__.__name__ - stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} + stats_tags = ti.stats_tags stats.incr("operator_failures", tags={**stats_tags, "operator_name": operator}) stats.incr("ti_failures", tags=stats_tags) @@ -2044,9 +2052,7 @@ def finalize( # Record task duration metrics for all terminal states if ti.start_date and ti.end_date: duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000 - stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} - - stats.timing("task.duration", duration_ms, tags=stats_tags) + stats.timing("task.duration", duration_ms, tags=ti.stats_tags) task = ti.task # Pushing xcom for each operator extra links defined on the operator only. 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index eda5ea9d9cddf..61ffa722304d3 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5216,6 +5216,55 @@ def test_operator_failures_metrics_emitted(self, create_runtime_ti, mock_supervi ) backend.incr.assert_any_call("ti_failures", tags=stats_tags) + @pytest.mark.parametrize( + ("team_name", "expected_tags_extra"), + [ + pytest.param("my_team", {"team_name": "my_team"}, id="with_team"), + pytest.param(None, {}, id="without_team"), + ], + ) + def test_ti_start_metric_respects_team_name( + self, team_name, expected_tags_extra, create_runtime_ti, mock_supervisor_comms + ): + task = PythonOperator(task_id="test", python_callable=lambda: "success") + ti = create_runtime_ti(task=task) + if team_name: + ti._ti_context_from_server.dag_run.team_name = team_name + + with mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as mock_get_backend: + backend = mock.MagicMock(spec=StatsLogger) + mock_get_backend.return_value = backend + run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + expected = {"dag_id": ti.dag_id, "task_id": ti.task_id, **expected_tags_extra} + backend.incr.assert_any_call("ti.start", tags=expected) + + @pytest.mark.parametrize( + ("task_callable", "operator_metric", "ti_metric"), + [ + pytest.param(lambda: "success", "operator_successes", "ti_successes", id="success"), + pytest.param(lambda: 1 / 0, "operator_failures", "ti_failures", id="failure"), + ], + ) + def test_operator_metrics_respect_team_name( + self, task_callable, operator_metric, ti_metric, create_runtime_ti, mock_supervisor_comms + ): + task = PythonOperator(task_id="test", python_callable=task_callable) + ti = create_runtime_ti(task=task) + ti._ti_context_from_server.dag_run.team_name = "team_a" + + with 
mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as mock_get_backend: + backend = mock.MagicMock(spec=StatsLogger) + mock_get_backend.return_value = backend + run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id, "team_name": "team_a"} + backend.incr.assert_any_call( + operator_metric, + tags={**stats_tags, "operator_name": "PythonOperator"}, + ) + backend.incr.assert_any_call(ti_metric, tags=stats_tags) + class TestDetailSpan: """Tests for the detail_span decorator / context manager."""