Skip to content

Commit

Permalink
Experimental: Support custom weight_rule implementation to calculate …
Browse files Browse the repository at this point in the history
…the TI priority_weight
  • Loading branch information
hussein-awala committed Mar 17, 2024
1 parent 8839e0a commit 6581f7b
Show file tree
Hide file tree
Showing 20 changed files with 540 additions and 32 deletions.
13 changes: 9 additions & 4 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.weight_rule import WeightRule


class CronExpression(typing.NamedTuple):
Expand Down Expand Up @@ -138,9 +137,15 @@ def __init__(self, **metadata):
class WeightRuleField(fields.String):
"""Schema for WeightRule."""

# NOTE(review): this text is a diff rendering without +/- markers. The ``__init__`` below is
# presumably the removed side of the hunk and the two methods after it the added side -- confirm
# against the commit before treating all three as live code.
def __init__(self, **metadata):
super().__init__(**metadata)
# Puts a OneOf validator over WeightRule.all_weight_rules() ahead of the existing validators.
self.validators = [validate.OneOf(WeightRule.all_weight_rules()), *self.validators]
def _serialize(self, value, attr, obj, **kwargs):
# The encoder is imported inside the method rather than at module level.
from airflow.serialization.serialized_objects import encode_priority_weight_strategy

# The field's serialized form is whatever encode_priority_weight_strategy returns for ``value``.
return encode_priority_weight_strategy(value)

def _deserialize(self, value, attr, data, **kwargs):
# The decoder is imported inside the method rather than at module level.
from airflow.serialization.serialized_objects import decode_priority_weight_strategy

# The field's deserialized form is whatever decode_priority_weight_strategy returns for ``value``.
return decode_priority_weight_strategy(value)


class TimezoneField(fields.String):
Expand Down
65 changes: 65 additions & 0 deletions airflow/example_dags/example_priority_weight_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of a custom PriorityWeightStrategy class."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pendulum

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.operators.python import PythonOperator

if TYPE_CHECKING:
from airflow.models import TaskInstance


def success_on_third_attempt(ti: TaskInstance, **context):
    """Fail while the task instance's try number is below 3, then succeed.

    :param ti: task instance whose ``try_number`` is inspected
    :param context: remaining keyword arguments; accepted but not used
    :raises Exception: with the message ``"Not yet"`` when ``ti.try_number`` is less than 3
    """
    current_try = ti.try_number
    if current_try < 3:
        raise Exception("Not yet")
    return None


# Example DAG with two PythonOperator tasks that run the same callable under different
# ``weight_rule`` values. ``default_args`` sets 3 retries with a 10-second retry delay.
# NOTE(review): indentation was lost in this rendering; the try/except below presumably sits
# inside the ``with DAG`` block (the operators are given no ``dag=`` argument) -- confirm.
with DAG(
dag_id="example_priority_weight_strategy",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
schedule="@daily",
tags=["example"],
default_args={
"retries": 3,
"retry_delay": pendulum.duration(seconds=10),
},
) as dag:
# Task whose weight rule is given by the name "downstream".
fixed_weight_task = PythonOperator(
task_id="fixed_weight_task",
python_callable=success_on_third_attempt,
weight_rule="downstream",
)

# Task whose weight rule is given as the string path
# "decreasing_priority_weight_strategy.DecreasingPriorityStrategy".
try:
decreasing_weight_task = PythonOperator(
task_id="decreasing_weight_task",
python_callable=success_on_third_attempt,
weight_rule="decreasing_priority_weight_strategy.DecreasingPriorityStrategy",
)
except AirflowException as e:
# In the unit tests, we don't load the example plugins, so the custom strategy is not available.
# Only an error whose message contains "Unknown priority strategy" is tolerated;
# any other AirflowException is re-raised.
if "Unknown priority strategy" not in str(e):
raise
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.plugins_manager import AirflowPlugin
from airflow.task.priority_strategy import PriorityWeightStrategy

if TYPE_CHECKING:
from airflow.models import TaskInstance


# [START custom_priority_weight_strategy]
class DecreasingPriorityStrategy(PriorityWeightStrategy):
    """Priority weight strategy whose weight goes down as the task instance's try counter goes up."""

    def get_weight(self, ti: TaskInstance):
        """Return ``3 - ti._try_number + 1``, but never less than 1.

        :param ti: task instance whose ``_try_number`` attribute drives the weight
        """
        tries_recorded = ti._try_number
        candidate_weight = 3 - tries_recorded + 1
        # Floor the result at 1 once the computed value would drop lower.
        return max(candidate_weight, 1)


class DecreasingPriorityWeightStrategyPlugin(AirflowPlugin):
    """Plugin that lists ``DecreasingPriorityStrategy`` under ``priority_weight_strategies``."""

    # Strategy classes exposed by this plugin.
    priority_weight_strategies = [DecreasingPriorityStrategy]
    # Name under which the plugin is declared.
    name = "decreasing_priority_weight_strategy_plugin"


# [END custom_priority_weight_strategy]
2 changes: 1 addition & 1 deletion airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
command_list_to_run,
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)

Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def queue_task_instance(
self.queue_command(
task_instance,
[str(task_instance)], # Just for better logging, it's not used anywhere
priority=task_instance.task.priority_weight_total,
priority=task_instance.priority_weight,
queue=task_instance.task.queue,
)
# Save params for TaskInstance._run_raw_task
Expand Down
22 changes: 18 additions & 4 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import datetime
import inspect
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence

Expand Down Expand Up @@ -53,6 +54,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
Expand Down Expand Up @@ -97,7 +99,7 @@ class AbstractOperator(Templater, DAGNode):

operator_class: type[BaseOperator] | dict[str, Any]

weight_rule: str
weight_rule: PriorityWeightStrategy
priority_weight: int

# Defines the operator level extra links.
Expand Down Expand Up @@ -397,11 +399,23 @@ def priority_weight_total(self) -> int:
- WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
- WeightRule.UPSTREAM - adds priority weight of all upstream tasks
"""
if self.weight_rule == WeightRule.ABSOLUTE:
from airflow.task.priority_strategy import (
_AbsolutePriorityWeightStrategy,
_DownstreamPriorityWeightStrategy,
_UpstreamPriorityWeightStrategy,
)

warnings.warn(
"Accessing `priority_weight_total` from AbstractOperator instance is deprecated."
" Please use `priority_weight` from task instance instead.",
DeprecationWarning,
stacklevel=2,
)
if type(self.weight_rule) == _AbsolutePriorityWeightStrategy:
return self.priority_weight
elif self.weight_rule == WeightRule.DOWNSTREAM:
elif type(self.weight_rule) == _DownstreamPriorityWeightStrategy:
upstream = False
elif self.weight_rule == WeightRule.UPSTREAM:
elif type(self.weight_rule) == _UpstreamPriorityWeightStrategy:
upstream = True
else:
upstream = False
Expand Down
21 changes: 11 additions & 10 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
Expand All @@ -94,7 +95,6 @@
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -244,7 +244,7 @@ def partial(
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | ArgNotSet = NOTSET,
weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
map_index_template: str | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
Expand Down Expand Up @@ -575,6 +575,13 @@ class derived from this one results in the creation of a task object,
significantly speeding up the task creation process as for very large
DAGs. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
|experimental|
Since 2.9.0, Airflow allows defining a custom priority weight strategy
by creating a subclass of
``airflow.task.priority_strategy.PriorityWeightStrategy``, registering
it in a plugin, and then providing the class path or the class instance via
the ``weight_rule`` parameter. The custom priority weight strategy will be
used to calculate the effective total priority weight of the task instance.
:param queue: which queue to target when running this job. Not
all executors implement queue management, the CeleryExecutor
does support targeting specific queues.
Expand Down Expand Up @@ -767,7 +774,7 @@ def __init__(
params: collections.abc.MutableMapping | None = None,
default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
Expand Down Expand Up @@ -918,13 +925,7 @@ def __init__(
f"received '{type(priority_weight)}'."
)
self.priority_weight = priority_weight
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
f"The weight_rule must be one of "
f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; "
f"received '{weight_rule}'."
)
self.weight_rule = weight_rule
self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
self.resources = coerce_resources(resources)
if task_concurrency and not max_active_tis_per_dag:
# TODO: Remove in Airflow 3.0
Expand Down
11 changes: 7 additions & 4 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
Expand Down Expand Up @@ -534,12 +535,14 @@ def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override]
return validate_and_load_priority_weight_strategy(
self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
)

@weight_rule.setter
def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value
def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

@property
def sla(self) -> datetime.timedelta | None:
Expand Down
12 changes: 10 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,11 @@ def _refresh_from_task(
task_instance.queue = task.queue
task_instance.pool = pool_override or task.pool
task_instance.pool_slots = task.pool_slots
task_instance.priority_weight = task.priority_weight_total
with contextlib.suppress(Exception):
# This method is called from the different places, and sometimes the TI is not fully initialized
task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
task_instance # type: ignore[arg-type]
)
task_instance.run_as_user = task.run_as_user
# Do not set max_tries to task.retries here because max_tries is a cumulative
# value that needs to be stored in the db.
Expand Down Expand Up @@ -1421,6 +1425,10 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any
:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
)

return {
"dag_id": task.dag_id,
"task_id": task.task_id,
Expand All @@ -1431,7 +1439,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any
"queue": task.queue,
"pool": task.pool,
"pool_slots": task.pool_slots,
"priority_weight": task.priority_weight_total,
"priority_weight": priority_weight,
"run_as_user": task.run_as_user,
"max_tries": task.retries,
"executor_config": task.executor_config,
Expand Down

0 comments on commit 6581f7b

Please sign in to comment.