Enable pools to consider deferred tasks (#32709)
* Makes pools respect deferrable tasks (with extra setting)

See #21243

This commit makes pools consider deferred tasks when the `include_deferred` flag is set. By default a pool does not count deferred tasks as occupied slots, but it still reports the number of deferred tasks in its stats.

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Usiel and uranusjr authored Aug 9, 2023
1 parent f82acc1 commit 70a050b
Showing 37 changed files with 549 additions and 166 deletions.
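The visible hunks below do not show the Pool model change where the new slot accounting actually lives. As a rough sketch of the behaviour the commit message describes (deferred tasks count toward occupancy only when `include_deferred` is set, but are always reported separately), with purely hypothetical names:

# Illustrative sketch only, not code from this commit; the real accounting
# lives in airflow.models.pool.Pool.
from dataclasses import dataclass


@dataclass
class PoolStatsSketch:
    slots: int
    running: int = 0
    queued: int = 0
    deferred: int = 0
    include_deferred: bool = False

    @property
    def occupied(self) -> int:
        # Deferred tasks occupy slots only when the flag is enabled.
        return self.running + self.queued + (self.deferred if self.include_deferred else 0)

    @property
    def open(self) -> int:
        return self.slots - self.occupied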
3 changes: 2 additions & 1 deletion airflow/api/client/api_client.py
@@ -60,12 +60,13 @@ def get_pools(self):
"""Get all pools."""
raise NotImplementedError()

def create_pool(self, name, slots, description):
def create_pool(self, name, slots, description, include_deferred):
"""Create a pool.
:param name: pool name
:param slots: pool slots amount
:param description: pool description
:param include_deferred: include deferred tasks in pool calculations
"""
raise NotImplementedError()

9 changes: 6 additions & 3 deletions airflow/api/client/json_client.py
@@ -108,23 +108,26 @@ def get_pools(self):
pools = self._request(url)
return [(p["pool"], p["slots"], p["description"]) for p in pools]

def create_pool(self, name: str, slots: int, description: str):
def create_pool(self, name: str, slots: int, description: str, include_deferred: bool):
"""Create a new pool.
:param name: The name of the pool to create.
:param slots: The number of slots in the pool.
:param description: A description of the pool.
:param include_deferred: include deferred tasks in pool calculations
:return: A tuple containing the name of the pool, the number of slots in the pool,
and a description of the pool.
a description of the pool and the include_deferred flag.
"""
endpoint = "/api/experimental/pools"
data = {
"name": name,
"slots": slots,
"description": description,
"include_deferred": include_deferred,
}
response = self._request(urljoin(self._api_base_url, endpoint), method="POST", json=data)
return response["pool"], response["slots"], response["description"]
return response["pool"], response["slots"], response["description"], response["include_deferred"]

def delete_pool(self, name: str):
"""Delete a pool.
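The JSON client simply sends the new field alongside the existing ones. A minimal sketch of the equivalent raw request, assuming a local webserver at http://localhost:8080 and using requests directly instead of the client class (the experimental API is deprecated and may require authentication in a real deployment):

# Sketch only: mirrors the payload json_client.Client.create_pool builds above.
import requests

payload = {
    "name": "backfill_pool",          # hypothetical pool
    "slots": 16,
    "description": "pool for backfills",
    "include_deferred": True,         # field added by this commit
}
response = requests.post(
    "http://localhost:8080/api/experimental/pools",  # assumed base URL
    json=payload,
)
response.raise_for_status()
data = response.json()
print(data["pool"], data["slots"], data["description"], data["include_deferred"])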
10 changes: 6 additions & 4 deletions airflow/api/client/local_client.py
@@ -63,12 +63,12 @@ def get_pool(self, name):
pool = Pool.get_pool(pool_name=name)
if not pool:
raise PoolNotFound(f"Pool {name} not found")
return pool.pool, pool.slots, pool.description
return pool.pool, pool.slots, pool.description, pool.include_deferred

def get_pools(self):
return [(p.pool, p.slots, p.description) for p in Pool.get_pools()]
return [(p.pool, p.slots, p.description, p.include_deferred) for p in Pool.get_pools()]

def create_pool(self, name, slots, description):
def create_pool(self, name, slots, description, include_deferred):
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")
pool_name_length = Pool.pool.property.columns[0].type.length
@@ -78,7 +78,9 @@ def create_pool(self, name, slots, description):
slots = int(slots)
except ValueError:
raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
pool = Pool.create_or_update_pool(name=name, slots=slots, description=description)
pool = Pool.create_or_update_pool(
name=name, slots=slots, description=description, include_deferred=include_deferred
)
return pool.pool, pool.slots, pool.description

def delete_pool(self, name):
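Note that with the local client the tuples returned by get_pool and get_pools now carry four elements, while create_pool (per the hunk above) still returns three. A small usage sketch, with the constructor arguments assumed:

from airflow.api.client.local_client import Client

client = Client(None, None)  # constructor arguments assumed for this sketch
client.create_pool(name="deferrable_pool", slots=8, description="", include_deferred=True)

# get_pools() now yields (pool, slots, description, include_deferred) tuples
for pool, slots, description, include_deferred in client.get_pools():
    print(pool, slots, description, include_deferred)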
2 changes: 1 addition & 1 deletion airflow/api/common/experimental/pool.py
@@ -68,7 +68,7 @@ def create_pool(name, slots, description, session: Session = NEW_SESSION):
session.expire_on_commit = False
pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
pool = Pool(pool=name, slots=slots, description=description, include_deferred=False)
session.add(pool)
else:
pool.slots = slots
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/pool_endpoint.py
@@ -88,10 +88,10 @@ def patch_pool(
) -> APIResponse:
"""Update a pool."""
request_dict = get_json_request_dict()
# Only slots can be modified in 'default_pool'
# Only slots and include_deferred can be modified in 'default_pool'
try:
if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME:
if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots":
if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask):
pass
else:
raise BadRequest(detail="Default Pool's name can't be modified")
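On the stable REST API this means default_pool can now have both slots and include_deferred patched, while renaming it is still rejected. A hedged sketch of such a request, assuming a local webserver and basic-auth credentials:

# Sketch: only 'slots' and 'include_deferred' are allowed in the update mask
# when patching default_pool.
import requests

response = requests.patch(
    "http://localhost:8080/api/v1/pools/default_pool",   # assumed base URL
    params={"update_mask": "slots,include_deferred"},
    json={"name": "default_pool", "slots": 128, "include_deferred": True},
    auth=("admin", "admin"),                              # assumed credentials
)
response.raise_for_status()
print(response.json())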
15 changes: 14 additions & 1 deletion airflow/api_connexion/openapi/v1.yaml
@@ -3158,7 +3158,7 @@ components:
occupied_slots:
type: integer
readOnly: true
description: The number of slots used by running/queued tasks at the moment.
description: The number of slots used by running/queued tasks at the moment. May include deferred tasks if 'include_deferred' is set to true.
running_slots:
type: integer
readOnly: true
@@ -3175,13 +3175,26 @@
type: integer
readOnly: true
description: The number of slots used by scheduled tasks at the moment.
deferred_slots:
type: integer
readOnly: true
description: |
The number of slots used by deferred tasks at the moment. Relevant if 'include_deferred' is set to true.
*New in version 2.7.0*
description:
type: string
description: |
The description of the pool.
*New in version 2.3.0*
nullable: true
include_deferred:
type: boolean
description: |
If set to true, deferred tasks are considered when calculating open pool slots.
*New in version 2.7.0*
PoolCollection:
type: object
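A pool resource returned by the stable API therefore gains two fields. An illustrative response shape (values made up, shown as a Python dict for brevity):

example_pool = {
    "name": "deferrable_pool",
    "slots": 8,
    "occupied_slots": 4,       # includes deferred tasks because include_deferred is true
    "running_slots": 2,
    "queued_slots": 1,
    "scheduled_slots": 0,
    "open_slots": 4,
    "deferred_slots": 1,       # new in 2.7.0
    "description": "pool for deferrable workloads",
    "include_deferred": True,  # new in 2.7.0
}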
8 changes: 8 additions & 0 deletions airflow/api_connexion/schemas/pool_schema.py
@@ -39,7 +39,10 @@ class Meta:
queued_slots = fields.Method("get_queued_slots", dump_only=True)
scheduled_slots = fields.Method("get_scheduled_slots", dump_only=True)
open_slots = fields.Method("get_open_slots", dump_only=True)
deferred_slots = fields.Method("get_deferred_slots", dump_only=True)
description = auto_field()
# we skip auto_field() here to be compatible with the manual validation in the pool_endpoint module
include_deferred = fields.Boolean(load_default=False)

@staticmethod
def get_occupied_slots(obj: Pool) -> int:
@@ -61,6 +64,11 @@ def get_scheduled_slots(obj: Pool) -> int:
"""Returns the scheduled slots of the pool."""
return obj.scheduled_slots()

@staticmethod
def get_deferred_slots(obj: Pool) -> int:
"""Returns the deferred slots of the pool."""
return obj.deferred_slots()

@staticmethod
def get_open_slots(obj: Pool) -> float:
"""Returns the open slots of the pool."""
16 changes: 13 additions & 3 deletions airflow/cli/cli_config.py
@@ -513,6 +513,9 @@ def string_lower_type(val):
ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots")
ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool description")
ARG_POOL_INCLUDE_DEFERRED = Arg(
("--include-deferred",), help="Include deferred tasks in calculations for Pool", action="store_true"
)
ARG_POOL_IMPORT = Arg(
("file",),
metavar="FILEPATH",
@@ -521,8 +524,8 @@ def string_lower_type(val):
textwrap.dedent(
"""
{
"pool_1": {"slots": 5, "description": ""},
"pool_2": {"slots": 10, "description": "test"}
"pool_1": {"slots": 5, "description": "", "include_deferred": true},
"pool_2": {"slots": 10, "description": "test", "include_deferred": false}
}"""
),
" " * 4,
@@ -1456,7 +1459,14 @@ class GroupCommand(NamedTuple):
name="set",
help="Configure pool",
func=lazy_load_command("airflow.cli.commands.pool_command.pool_set"),
args=(ARG_POOL_NAME, ARG_POOL_SLOTS, ARG_POOL_DESCRIPTION, ARG_OUTPUT, ARG_VERBOSE),
args=(
ARG_POOL_NAME,
ARG_POOL_SLOTS,
ARG_POOL_DESCRIPTION,
ARG_POOL_INCLUDE_DEFERRED,
ARG_OUTPUT,
ARG_VERBOSE,
),
),
ActionCommand(
name="delete",
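The import/export format carries the flag as well. A small sketch that writes a pools file in the format documented for ARG_POOL_IMPORT; the commands in the comments are the CLI entry points wired up above:

# Sketch: create a pools file for `airflow pools import pools.json`.
# A single pool can equally be created with:
#   airflow pools set pool_1 5 "" --include-deferred
import json

pools = {
    "pool_1": {"slots": 5, "description": "", "include_deferred": True},
    "pool_2": {"slots": 10, "description": "test", "include_deferred": False},
}
with open("pools.json", "w") as f:
    json.dump(pools, f, indent=4, sort_keys=True)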
18 changes: 14 additions & 4 deletions airflow/cli/commands/pool_command.py
@@ -38,6 +38,7 @@ def _show_pools(pools, output):
"pool": x[0],
"slots": x[1],
"description": x[2],
"include_deferred": x[3],
},
)

@@ -69,7 +70,9 @@ def pool_get(args):
def pool_set(args):
"""Creates new pool with a given name and slots."""
api_client = get_current_api_client()
api_client.create_pool(name=args.pool, slots=args.slots, description=args.description)
api_client.create_pool(
name=args.pool, slots=args.slots, description=args.description, include_deferred=args.include_deferred
)
print(f"Pool {args.pool} created")


@@ -119,8 +122,15 @@ def pool_import_helper(filepath):
pools = []
failed = []
for k, v in pools_json.items():
if isinstance(v, dict) and len(v) == 2:
pools.append(api_client.create_pool(name=k, slots=v["slots"], description=v["description"]))
if isinstance(v, dict) and "slots" in v and "description" in v:
pools.append(
api_client.create_pool(
name=k,
slots=v["slots"],
description=v["description"],
include_deferred=v.get("include_deferred", False),
)
)
else:
failed.append(k)
return pools, failed
@@ -132,7 +142,7 @@ def pool_export_helper(filepath):
pool_dict = {}
pools = api_client.get_pools()
for pool in pools:
pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]}
pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2], "include_deferred": pool[3]}
with open(filepath, "w") as poolfile:
poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
return pools
2 changes: 2 additions & 0 deletions airflow/jobs/scheduler_job_runner.py
@@ -1571,10 +1571,12 @@ def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
Stats.gauge(f"pool.open_slots.{pool_name}", slot_stats["open"])
Stats.gauge(f"pool.queued_slots.{pool_name}", slot_stats["queued"])
Stats.gauge(f"pool.running_slots.{pool_name}", slot_stats["running"])
Stats.gauge(f"pool.deferred_slots.{pool_name}", slot_stats["deferred"])
# Same metrics with tagging
Stats.gauge("pool.open_slots", slot_stats["open"], tags={"pool_name": pool_name})
Stats.gauge("pool.queued_slots", slot_stats["queued"], tags={"pool_name": pool_name})
Stats.gauge("pool.running_slots", slot_stats["running"], tags={"pool_name": pool_name})
Stats.gauge("pool.deferred_slots", slot_stats["deferred"], tags={"pool_name": pool_name})

@provide_session
def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""add include_deferred column to pool
Revision ID: 405de8318b3a
Revises: 788397e78828
Create Date: 2023-07-20 04:22:21.007342
"""

import sqlalchemy as sa
from alembic import op


# revision identifiers, used by Alembic.
revision = "405de8318b3a"
down_revision = "788397e78828"
branch_labels = None
depends_on = None
airflow_version = "2.7.0"


def upgrade():
"""Apply add include_deferred column to pool"""
with op.batch_alter_table("slot_pool") as batch_op:
batch_op.add_column(
sa.Column("include_deferred", sa.Boolean, nullable=False, server_default=sa.false())
)


def downgrade():
"""Unapply add include_deferred column to pool"""
with op.batch_alter_table("slot_pool") as batch_op:
batch_op.drop_column("include_deferred")
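The matching ORM attribute on the Pool model is changed in a hunk not expanded here; presumably it mirrors this migration with something along these lines (assumed fragment, not the actual model code):

# Assumed model-side counterpart of the migration; the real definition
# lives in airflow/models/pool.py.
from sqlalchemy import Boolean, Column

include_deferred = Column(Boolean, nullable=False, default=False)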