Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions airflow/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from airflow import configuration
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State


class DaskExecutor(BaseExecutor):
Expand All @@ -41,27 +40,28 @@ def start(self):

def execute_async(self, key, command, queue=None):
if queue is not None:
warnings.warning(
warnings.warn(
'DaskExecutor does not support queues. All tasks will be run '
'in the same cluster')

def airflow_run():
return subprocess.check_call(command, shell=True)

future = self.client.submit(airflow_run, pure=False)
self.futures[future] = key

def _process_future(self, future):
if future.done():
key = self.futures[future]
if future.exception():
self.change_state(key, State.FAILED)
self.fail(key)
self.logger.error("Failed to execute task: {}".format(
repr(future.exception())))
elif future.cancelled():
self.change_state(key, State.FAILED)
self.fail(key)
self.logger.error("Failed to execute task")
else:
self.change_state(key, State.SUCCESS)
self.success(key)
self.futures.pop(future)

def sync(self):
Expand All @@ -70,7 +70,7 @@ def sync(self):
self._process_future(future)

def end(self):
for future in distributed.as_completed(self.futures):
for future in distributed.as_completed(self.futures.copy()):
self._process_future(future)

def terminate(self):
Expand Down
74 changes: 5 additions & 69 deletions tests/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from airflow.operators.python_operator import PythonOperator

try:
from airflow.executors import DaskExecutor
from airflow.executors.dask_executor import DaskExecutor
from distributed import LocalCluster
SKIP_DASK = False
except ImportError:
Expand All @@ -48,6 +48,9 @@ def test_dask_executor_functions(self):

executor = DaskExecutor(cluster_address=cluster.scheduler_address)

# start the executor
executor.start()

success_command = 'echo 1'
fail_command = 'exit 1'

Expand All @@ -60,7 +63,7 @@ def test_dask_executor_functions(self):
k for k, v in executor.futures.items() if v == 'fail')

# wait for the futures to execute, with a timeout
timeout = datetime.datetime.now() + datetime.timedelta(seconds=0.5)
timeout = datetime.datetime.now() + datetime.timedelta(seconds=30)
while not (success_future.done() and fail_future.done()):
if datetime.datetime.now() > timeout:
raise ValueError(
Expand All @@ -75,73 +78,6 @@ def test_dask_executor_functions(self):
self.assertTrue(success_future.exception() is None)
self.assertTrue(fail_future.exception() is not None)

# tell the executor to shut down
executor.end()
self.assertTrue(len(executor.futures) == 0)

cluster.close()

@unittest.skipIf(SKIP_DASK, 'Dask unsupported by this configuration')
def test_submit_task_instance_to_dask_cluster(self):
"""
Test that the DaskExecutor properly submits tasks to the cluster
"""
cluster = LocalCluster(nanny=False)

executor = DaskExecutor(cluster_address=cluster.scheduler_address)

args = dict(
start_date=DEFAULT_DATE
)

def fail():
raise ValueError('Intentional failure.')

with DAG('test-dag', default_args=args) as dag:
# queue should be allowed, but ignored
success_operator = PythonOperator(
task_id='success',
python_callable=lambda: True,
queue='queue')

fail_operator = PythonOperator(
task_id='fail',
python_callable=fail)

success_ti = TaskInstance(
success_operator,
execution_date=DEFAULT_DATE)

fail_ti = TaskInstance(
fail_operator,
execution_date=DEFAULT_DATE)

# queue the tasks
executor.queue_task_instance(success_ti)
executor.queue_task_instance(fail_ti)

# the tasks haven't been submitted to the cluster yet
self.assertTrue(len(executor.futures) == 0)
# after the heartbeat, they have been submitted
executor.heartbeat()
self.assertTrue(len(executor.futures) == 2)

# wait a reasonable amount of time for the tasks to complete
for _ in range(2):
time.sleep(0.25)
executor.heartbeat()

# check that the futures were completed
if len(executor.futures) == 2:
raise ValueError('Failed to reach cluster before timeout.')
self.assertTrue(len(executor.futures) == 0)

# check that the taskinstances were updated
success_ti.refresh_from_db()
self.assertTrue(success_ti.state == State.SUCCESS)
fail_ti.refresh_from_db()
self.assertTrue(fail_ti.state == State.FAILED)

cluster.close()


Expand Down