Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental: Support custom weight_rule implementation to calculate the TI priority_weight #38222

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 9 additions & 4 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.weight_rule import WeightRule


class CronExpression(typing.NamedTuple):
Expand Down Expand Up @@ -138,9 +137,15 @@ def __init__(self, **metadata):
class WeightRuleField(fields.String):
"""Schema for WeightRule."""

def __init__(self, **metadata):
super().__init__(**metadata)
self.validators = [validate.OneOf(WeightRule.all_weight_rules()), *self.validators]
def _serialize(self, value, attr, obj, **kwargs):
    """Serialize the weight rule by delegating to ``encode_priority_weight_strategy``.

    ``attr``, ``obj`` and any extra keyword arguments are accepted but not used here.

    :param value: the weight rule value to encode.
    :return: the result of ``encode_priority_weight_strategy(value)``.
    """
    # Function-scope import, resolved only when the field is actually serialized.
    from airflow.serialization.serialized_objects import encode_priority_weight_strategy as _encode

    encoded = _encode(value)
    return encoded

def _deserialize(self, value, attr, data, **kwargs):
    """Deserialize the weight rule by delegating to ``decode_priority_weight_strategy``.

    ``attr``, ``data`` and any extra keyword arguments are accepted but not used here.

    :param value: the serialized weight rule value to decode.
    :return: the result of ``decode_priority_weight_strategy(value)``.
    """
    # Function-scope import, resolved only when the field is actually deserialized.
    from airflow.serialization.serialized_objects import decode_priority_weight_strategy as _decode

    decoded = _decode(value)
    return decoded


class TimezoneField(fields.String):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.plugins_manager import AirflowPlugin
from airflow.task.priority_strategy import PriorityWeightStrategy

if TYPE_CHECKING:
from airflow.models import TaskInstance


# [START custom_priority_weight_strategy]
class DecreasingPriorityStrategy(PriorityWeightStrategy):
    """Priority weight strategy whose weight goes down with each attempt of the DAG task."""

    def get_weight(self, ti: TaskInstance):
        """Return ``3 - ti._try_number + 1``, clamped so the result is never below ``1``.

        :param ti: the task instance whose ``_try_number`` drives the weight.
        """
        weight = 3 - ti._try_number + 1
        # Clamp to the floor of 1 once enough attempts have been made.
        if weight < 1:
            return 1
        return weight


class DecreasingPriorityWeightStrategyPlugin(AirflowPlugin):
    """Plugin that exposes ``DecreasingPriorityStrategy`` via ``priority_weight_strategies``."""

    # Name under which this plugin is declared.
    name = "decreasing_priority_weight_strategy_plugin"
    # Strategy classes contributed by this plugin.
    priority_weight_strategies = [DecreasingPriorityStrategy]


# [END custom_priority_weight_strategy]
2 changes: 1 addition & 1 deletion airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
command_list_to_run,
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)

Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
[str(task_instance)], # Just for better logging, it's not used anywhere
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)
# Save params for TaskInstance._run_raw_task
Expand Down
15 changes: 11 additions & 4 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
Expand Down Expand Up @@ -97,7 +98,7 @@ class AbstractOperator(Templater, DAGNode):

operator_class: type[BaseOperator] | dict[str, Any]

weight_rule: str
weight_rule: PriorityWeightStrategy
priority_weight: int

# Defines the operator level extra links.
Expand Down Expand Up @@ -397,11 +398,17 @@ def priority_weight_total(self) -> int:
- WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
- WeightRule.UPSTREAM - adds priority weight of all upstream tasks
"""
if self.weight_rule == WeightRule.ABSOLUTE:
from airflow.task.priority_strategy import (
_AbsolutePriorityWeightStrategy,
_DownstreamPriorityWeightStrategy,
_UpstreamPriorityWeightStrategy,
)

if type(self.weight_rule) == _AbsolutePriorityWeightStrategy:
return self.priority_weight
elif self.weight_rule == WeightRule.DOWNSTREAM:
elif type(self.weight_rule) == _DownstreamPriorityWeightStrategy:
upstream = False
elif self.weight_rule == WeightRule.UPSTREAM:
elif type(self.weight_rule) == _UpstreamPriorityWeightStrategy:
upstream = True
else:
upstream = False
Expand Down
21 changes: 11 additions & 10 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
Expand All @@ -94,7 +95,6 @@
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -244,7 +244,7 @@ def partial(
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | ArgNotSet = NOTSET,
weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
map_index_template: str | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
Expand Down Expand Up @@ -575,6 +575,13 @@ class derived from this one results in the creation of a task object,
significantly speeding up the task creation process as for very large
DAGs. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
|experimental|
Since 2.9.0, Airflow allows defining a custom priority weight strategy
by creating a subclass of
``airflow.task.priority_strategy.PriorityWeightStrategy``, registering it
in a plugin, and then providing the class path or a class instance via the
``weight_rule`` parameter. The custom priority weight strategy will be
used to calculate the effective total priority weight of the task instance.
:param queue: which queue to target when running this job. Not
all executors implement queue management, the CeleryExecutor
does support targeting specific queues.
Expand Down Expand Up @@ -767,7 +774,7 @@ def __init__(
params: collections.abc.MutableMapping | None = None,
default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
Expand Down Expand Up @@ -918,13 +925,7 @@ def __init__(
f"received '{type(priority_weight)}'."
)
self.priority_weight = priority_weight
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
f"The weight_rule must be one of "
f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; "
f"received '{weight_rule}'."
)
self.weight_rule = weight_rule
self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
self.resources = coerce_resources(resources)
if task_concurrency and not max_active_tis_per_dag:
# TODO: Remove in Airflow 3.0
Expand Down
11 changes: 7 additions & 4 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
Expand Down Expand Up @@ -534,12 +535,14 @@ def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override]
return validate_and_load_priority_weight_strategy(
self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
)

@weight_rule.setter
def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value
def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

@property
def sla(self) -> datetime.timedelta | None:
Expand Down
12 changes: 10 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,11 @@ def _refresh_from_task(
task_instance.queue = task.queue
task_instance.pool = pool_override or task.pool
task_instance.pool_slots = task.pool_slots
task_instance.priority_weight = task.priority_weight_total
with contextlib.suppress(Exception):
# This method is called from different places, and sometimes the TI is not fully initialized
task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
task_instance # type: ignore[arg-type]
)
task_instance.run_as_user = task.run_as_user
# Do not set max_tries to task.retries here because max_tries is a cumulative
# value that needs to be stored in the db.
Expand Down Expand Up @@ -1421,6 +1425,10 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any

:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
)

return {
"dag_id": task.dag_id,
"task_id": task.task_id,
Expand All @@ -1431,7 +1439,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any
"queue": task.queue,
"pool": task.pool,
"pool_slots": task.pool_slots,
"priority_weight": task.priority_weight_total,
"priority_weight": priority_weight,
"run_as_user": task.run_as_user,
"max_tries": task.retries,
"executor_config": task.executor_config,
Expand Down
36 changes: 35 additions & 1 deletion airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from typing import TYPE_CHECKING, Any, Iterable

from airflow import settings
from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
)
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import import_string, qualname
Expand Down Expand Up @@ -68,6 +72,7 @@
registered_operator_link_classes: dict[str, type] | None = None
registered_ti_dep_classes: dict[str, type] | None = None
timetable_classes: dict[str, type[Timetable]] | None = None
priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
"""
Mapping of class names to class of OperatorLinks registered by plugins.

Expand All @@ -89,6 +94,7 @@
"ti_deps",
"timetables",
"listeners",
"priority_weight_strategies",
}


Expand Down Expand Up @@ -169,6 +175,9 @@ class AirflowPlugin:

listeners: list[ModuleType | object] = []

# A list of priority weight strategy classes that can be used for calculating the priority weight of tasks.
priority_weight_strategies: list[type[PriorityWeightStrategy]] = []

@classmethod
def validate(cls):
"""Validate if plugin has a name."""
Expand Down Expand Up @@ -556,7 +565,7 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str
for attr in attrs_to_dump:
if attr in ("global_operator_extra_links", "operator_extra_links"):
info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)]
elif attr in ("macros", "timetables", "hooks", "executors"):
elif attr in ("macros", "timetables", "hooks", "executors", "priority_weight_strategies"):
info[attr] = [qualname(d) for d in getattr(plugin, attr)]
elif attr == "listeners":
# listeners may be modules or class instances
Expand All @@ -577,3 +586,28 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str
info[attr] = getattr(plugin, attr)
plugins_info.append(info)
return plugins_info


def initialize_priority_weight_strategy_plugins():
    """Populate the module-level registry of priority weight strategy classes.

    The registry (``priority_weight_strategy_classes``) maps ``qualname(cls)`` to the
    class. It combines the entries of ``airflow_priority_weight_strategies`` with the
    classes listed in each loaded plugin's ``priority_weight_strategies``; on a key
    clash the plugin-provided class wins. The work is done only once: if the registry
    is already set, the call returns immediately.

    :raises AirflowPluginException: if ``plugins`` is still ``None`` after
        ``ensure_plugins_loaded()`` has been called.
    """
    global priority_weight_strategy_classes

    # An earlier call already built the registry; nothing to do.
    if priority_weight_strategy_classes is not None:
        return

    ensure_plugins_loaded()

    if plugins is None:
        raise AirflowPluginException("Can't load plugins.")

    log.debug("Initialize extra priority weight strategy plugins")

    # Collect plugin-provided classes first; a later plugin overrides an earlier
    # one that registered a class with the same qualified name.
    from_plugins = {}
    for plugin in plugins:
        for strategy_class in plugin.priority_weight_strategies:
            from_plugins[qualname(strategy_class)] = strategy_class

    # Assign the global only once the merged mapping is complete, so a failure
    # above leaves the registry unset.
    merged = dict(airflow_priority_weight_strategies)
    merged.update(from_plugins)
    priority_weight_strategy_classes = merged