[AIRFLOW-4591] Make default_pool a real pool (apache#5349)
`non_pooled_task_slot_count` and `non_pooled_backfill_task_slot_count`
are removed in favor of a real pool, `default_pool`.

By default, tasks run in `default_pool`.
`default_pool` is initialized with 128 slots, and users can change the
number of slots through the UI or CLI. `default_pool` cannot be removed.
milton0825 authored and feng-tao committed Jun 20, 2019
1 parent 7bacdde commit 2c99ec6
Showing 20 changed files with 266 additions and 120 deletions.
9 changes: 9 additions & 0 deletions UPDATING.md
@@ -23,6 +23,15 @@ assists users migrating to a new version.

## Airflow Master

### Removal of `non_pooled_task_slot_count` and `non_pooled_backfill_task_slot_count`

`non_pooled_task_slot_count` and `non_pooled_backfill_task_slot_count`
are removed in favor of a real pool, `default_pool`.

By default, tasks run in `default_pool`.
`default_pool` is initialized with 128 slots, and users can change the
number of slots through the UI or CLI. `default_pool` cannot be removed.
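
A minimal sketch (editorial, not part of this commit) of resizing `default_pool` from Python, using the `Pool.get_default_pool` helper this commit adds; the UI and the `airflow pool` CLI remain the supported routes:

    from airflow.models import Pool
    from airflow.utils.db import create_session

    with create_session() as session:
        default_pool = Pool.get_default_pool(session=session)
        default_pool.slots = 256  # hypothetical new size; committed when the session closes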

### Changes to Google Transfer Operator
For pylint compatibility, the `filter` argument in `GcpTransferServiceOperationsListOperator`
has been renamed to `request_filter`.
3 changes: 3 additions & 0 deletions airflow/api/common/experimental/pool.py
@@ -72,6 +72,9 @@ def delete_pool(name, session=None):
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")

if name == Pool.DEFAULT_POOL_NAME:
raise AirflowBadRequest("default_pool cannot be deleted")

pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
raise PoolNotFound("Pool '%s' doesn't exist" % name)
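A hedged sketch of the new guard in action, assuming `AirflowBadRequest` is importable from `airflow.exceptions` as the code above suggests:

    from airflow.api.common.experimental.pool import delete_pool
    from airflow.exceptions import AirflowBadRequest

    try:
        delete_pool('default_pool')
    except AirflowBadRequest as err:
        print(err)  # "default_pool cannot be deleted"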
8 changes: 0 additions & 8 deletions airflow/config_templates/default_airflow.cfg
@@ -132,14 +132,6 @@ dag_concurrency = 16
# Are DAGs paused by default at creation
dags_are_paused_at_creation = True

# When not using pools, tasks are run in the "default pool",
# whose size is guided by this config element
non_pooled_task_slot_count = 128

# When not using pools, the number of backfill tasks per backfill
# is limited by this config element
non_pooled_backfill_task_slot_count = %(non_pooled_task_slot_count)s

# The maximum number of active DAG runs per DAG
max_active_runs_per_dag = 16

1 change: 0 additions & 1 deletion airflow/config_templates/default_test.cfg
@@ -46,7 +46,6 @@ donot_pickle = False
dag_concurrency = 16
dags_are_paused_at_creation = False
fernet_key = {FERNET_KEY}
non_pooled_task_slot_count = 128
enable_xcom_pickling = False
killed_task_cleanup_time = 5
secure_mode = False
35 changes: 13 additions & 22 deletions airflow/jobs/backfill_job.py
@@ -23,7 +23,6 @@

from sqlalchemy.orm.session import make_transient

from airflow import configuration as conf
from airflow import executors, models
from airflow.exceptions import (
AirflowException,
@@ -542,32 +541,24 @@ def _per_task_process(task, key, ti, session=None):
self.log.debug('Adding %s to not_ready', ti)
ti_status.not_ready.add(key)

non_pool_slots = conf.getint('core', 'non_pooled_backfill_task_slot_count')

try:
for task in self.dag.topological_sort():
for key, ti in list(ti_status.to_run.items()):
if task.task_id != ti.task_id:
continue
if task.pool:
pool = session.query(models.Pool) \
.filter(models.Pool.pool == task.pool) \
.first()
if not pool:
raise PoolNotFound('Unknown pool: {}'.format(task.pool))

open_slots = pool.open_slots(session=session)
if open_slots <= 0:
raise NoAvailablePoolSlot(
"Not scheduling since there are "
"%s open slots in pool %s".format(
open_slots, task.pool))
else:
if non_pool_slots <= 0:
raise NoAvailablePoolSlot(
"Not scheduling since there are no "
"non_pooled_backfill_task_slot_count.")
non_pool_slots -= 1

pool = session.query(models.Pool) \
.filter(models.Pool.pool == task.pool) \
.first()
if not pool:
raise PoolNotFound('Unknown pool: {}'.format(task.pool))

open_slots = pool.open_slots(session=session)
if open_slots <= 0:
raise NoAvailablePoolSlot(
"Not scheduling since there are "
"{0} open slots in pool {1}".format(
open_slots, task.pool))

num_running_task_instances_in_dag = DAG.get_num_task_instances(
self.dag_id,
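Because backfill now resolves every task to a real Pool row, remaining capacity can be inspected the same way for pooled and formerly non-pooled tasks; a short sketch under that assumption:

    from airflow import models
    from airflow.utils.db import create_session

    with create_session() as session:
        default_pool = session.query(models.Pool).filter(
            models.Pool.pool == 'default_pool').first()
        # open_slots() subtracts running/queued task instances from the pool size
        print(default_pool.open_slots(session=session))  # 128 on an idle installation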
21 changes: 7 additions & 14 deletions airflow/jobs/scheduler_job.py
@@ -893,21 +893,14 @@ def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
# any open slots in the pool.
for pool, task_instances in pool_to_task_instances.items():
pool_name = pool
if not pool:
# Arbitrary:
# If queued outside of a pool, trigger no more than
# non_pooled_task_slot_count
open_slots = models.Pool.default_pool_open_slots()
pool_name = models.Pool.default_pool_name
if pool not in pools:
self.log.warning(
"Tasks using non-existent pool '%s' will not be scheduled",
pool
)
open_slots = 0
else:
if pool not in pools:
self.log.warning(
"Tasks using non-existent pool '%s' will not be scheduled",
pool
)
open_slots = 0
else:
open_slots = pools[pool].open_slots(session=session)
open_slots = pools[pool].open_slots(session=session)

num_ready = len(task_instances)
self.log.info(
122 changes: 122 additions & 0 deletions airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py
@@ -0,0 +1,122 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Make TaskInstance.pool not nullable
Revision ID: 6e96a59344a4
Revises: 939bb1e647c8
Create Date: 2019-06-13 21:51:32.878437
"""

from alembic import op
import dill
import sqlalchemy as sa
from sqlalchemy import Column, Float, Integer, PickleType, String
from sqlalchemy.ext.declarative import declarative_base

from airflow.utils.db import create_session
from airflow.utils.sqlalchemy import UtcDateTime


# revision identifiers, used by Alembic.
revision = '6e96a59344a4'
down_revision = '939bb1e647c8'
branch_labels = None
depends_on = None


Base = declarative_base()
ID_LEN = 250


class TaskInstance(Base):
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the
state they are in.
The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
dag model deliberately to have more control over transactions.
Database transactions on this table should guard against double triggers and
any confusion around what task instances are or aren't ready to run,
even while multiple schedulers may be firing task instances.
"""

__tablename__ = "task_instance"

task_id = Column(String(ID_LEN), primary_key=True)
dag_id = Column(String(ID_LEN), primary_key=True)
execution_date = Column(UtcDateTime, primary_key=True)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Float)
state = Column(String(20))
_try_number = Column('try_number', Integer, default=0)
max_tries = Column(Integer)
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(50), nullable=False)
queue = Column(String(256))
priority_weight = Column(Integer)
operator = Column(String(1000))
queued_dttm = Column(UtcDateTime)
pid = Column(Integer)
executor_config = Column(PickleType(pickler=dill))


def upgrade():
"""
Make TaskInstance.pool field not nullable.
"""
with create_session() as session:
session.query(TaskInstance)\
.filter(TaskInstance.pool.is_(None))\
.update({TaskInstance.pool: 'default_pool'},
synchronize_session=False)  # Avoid selecting the updated rows
session.commit()

# use batch_alter_table to support SQLite workaround
with op.batch_alter_table('task_instance') as batch_op:
batch_op.alter_column(
column_name='pool',
type_=sa.String(50),
nullable=False,
)


def downgrade():
"""
Make TaskInstance.pool field nullable.
"""
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table('task_instance') as batch_op:
batch_op.alter_column(
column_name='pool',
type_=sa.String(50),
nullable=True,
)

with create_session() as session:
session.query(TaskInstance)\
.filter(TaskInstance.pool == 'default_pool')\
.update({TaskInstance.pool: None},
synchronize_session=False)  # Avoid selecting the updated rows
session.commit()
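
A hedged post-upgrade sanity check, mirroring the migration's own filter: once upgrade() has run, no task_instance row should retain a NULL pool:

    from airflow.models import TaskInstance
    from airflow.utils.db import create_session

    with create_session() as session:
        # The migration backfills NULL pools to 'default_pool', so this should hold
        assert session.query(TaskInstance).filter(
            TaskInstance.pool.is_(None)).count() == 0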
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
@@ -34,6 +34,7 @@
from airflow.exceptions import AirflowException
from airflow.lineage import prepare_lineage, apply_lineage, DataSet
from airflow.models.dag import DAG
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
@@ -258,7 +259,7 @@ def __init__(
priority_weight: int = 1,
weight_rule: str = WeightRule.DOWNSTREAM,
queue: str = configuration.conf.get('celery', 'default_queue'),
pool: Optional[str] = None,
pool: str = Pool.DEFAULT_POOL_NAME,
sla: Optional[timedelta] = None,
execution_timeout: Optional[timedelta] = None,
on_failure_callback: Optional[Callable] = None,
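With this new default, an operator created without an explicit pool reports `default_pool` instead of None; an illustrative sketch (the operator choice is arbitrary):

    from airflow.operators.dummy_operator import DummyOperator

    task = DummyOperator(task_id='example_task')
    assert task.pool == 'default_pool'  # was None before this commit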
17 changes: 8 additions & 9 deletions airflow/models/pool.py
@@ -19,7 +19,6 @@

from sqlalchemy import Column, Integer, String, Text, func

from airflow import conf
from airflow.models.base import Base
from airflow.utils.state import State
from airflow.utils.db import provide_session
@@ -33,20 +32,20 @@ class Pool(Base):
slots = Column(Integer, default=0)
description = Column(Text)

default_pool_name = 'not_pooled'
DEFAULT_POOL_NAME = 'default_pool'

def __repr__(self):
return self.pool

@staticmethod
@provide_session
def default_pool_open_slots(session):
from airflow.models import TaskInstance as TI # To avoid circular imports
total_slots = conf.getint('core', 'non_pooled_task_slot_count')
used_slots = session.query(func.count()).filter(
TI.pool == Pool.default_pool_name).filter(
TI.state.in_([State.RUNNING, State.QUEUED])).scalar()
return total_slots - used_slots
def get_pool(pool_name, session=None):
return session.query(Pool).filter(Pool.pool == pool_name).first()

@staticmethod
@provide_session
def get_default_pool(session=None):
return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

def to_json(self):
return {
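A usage sketch for the new helpers; `@provide_session` supplies a session when none is passed:

    from airflow.models.pool import Pool

    default_pool = Pool.get_default_pool()
    if default_pool is not None:
        print(default_pool.pool, default_pool.slots)  # e.g. default_pool 128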
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
@@ -140,7 +140,7 @@ class TaskInstance(Base, LoggingMixin):
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(50))
pool = Column(String(50), nullable=False)
queue = Column(String(256))
priority_weight = Column(Integer)
operator = Column(String(1000))
16 changes: 16 additions & 0 deletions airflow/utils/db.py
@@ -23,6 +23,7 @@
import contextlib

from airflow import settings
from airflow.configuration import conf
from airflow.utils.log.logging_mixin import LoggingMixin

log = LoggingMixin().log
@@ -78,6 +79,20 @@ def merge_conn(conn, session=None):
session.commit()


@provide_session
def add_default_pool_if_not_exists(session=None):
from airflow.models.pool import Pool
if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
default_pool = Pool(
pool=Pool.DEFAULT_POOL_NAME,
slots=conf.getint(section='core', key='non_pooled_task_slot_count',
fallback=128),
description="Default pool",
)
session.add(default_pool)
session.commit()


def initdb():
from airflow import models
from airflow.models import Connection
@@ -311,6 +326,7 @@ def upgradedb():
config.set_main_option('script_location', directory.replace('%', '%%'))
config.set_main_option('sqlalchemy.url', settings.SQL_ALCHEMY_CONN.replace('%', '%%'))
command.upgrade(config, 'heads')
add_default_pool_if_not_exists()


def resetdb():
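The bootstrap helper is idempotent, so calling it again (for example from a hypothetical test fixture) is safe:

    from airflow.utils.db import add_default_pool_if_not_exists

    add_default_pool_if_not_exists()  # inserts only when no default_pool row exists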
4 changes: 4 additions & 0 deletions airflow/www/views.py
@@ -2109,6 +2109,10 @@ class PoolModelView(AirflowModelView):
@action('muldelete', 'Delete', 'Are you sure you want to delete selected records?',
single=False)
def action_muldelete(self, items):
if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items):
flash("default_pool cannot be deleted", 'error')
self.update_redirect()
return redirect(self.get_redirect())
self.datamodel.delete_all(items)
self.update_redirect()
return redirect(self.get_redirect())
9 changes: 5 additions & 4 deletions tests/api/client/test_local_client.py
@@ -124,18 +124,19 @@ def test_get_pools(self):
self.client.create_pool(name='foo1', slots=1, description='')
self.client.create_pool(name='foo2', slots=2, description='')
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
self.assertEqual(pools, [('default_pool', 128, 'Default pool'),
('foo1', 1, ''), ('foo2', 2, '')])

def test_create_pool(self):
pool = self.client.create_pool(name='foo', slots=1, description='')
self.assertEqual(pool, ('foo', 1, ''))
with create_session() as session:
self.assertEqual(session.query(models.Pool).count(), 1)
self.assertEqual(session.query(models.Pool).count(), 2)

def test_delete_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
with create_session() as session:
self.assertEqual(session.query(models.Pool).count(), 1)
self.assertEqual(session.query(models.Pool).count(), 2)
self.client.delete_pool(name='foo')
with create_session() as session:
self.assertEqual(session.query(models.Pool).count(), 0)
self.assertEqual(session.query(models.Pool).count(), 1)
