Initial DaskRunner for Beam (apache#22421)
* WIP: Created a skeleton dask runner implementation.

* WIP: Idea for a translation evaluator.

* Added overrides and a visitor that translates operations.

* Fixed a dataclass typo.

* Expanded translations.

* Core idea seems to be kinda working...

* First iteration on DaskRunnerResult (keep track of pipeline state).

* Added minimal set of DaskRunner options.

* WIP: Almost got asserts to work! The current status is:
- CoGroupByKey is broken due to how tags are used with GroupByKey
- GroupByKey should output `[('0', None), ('1', 1)]`, but it actually outputs `[(None, ('1', 1)), (None, ('0', None))]`
- Once that is fixed, we may have working test pipelines on Dask.

* With a great 1-liner from @pabloem, groupby is fixed! Now, all three initial tests pass. (A sketch of the idea appears after these notes.)

* Self-review: Cleaned up dask runner impl.

* Self-review: Remove TODOs, delete commented out code, other cleanup.

* First pass at linting rules.

* WIP, include dask dependencies + test setup.

* WIP: maybe better dask deps?

* Skip dask tests depending on successful import.

* Fixed setup.py (missing `,`).

* Added an additional comma.

* Moved skipping logic to be above dask import.

* Fix lint issues with dask runner tests.

* Adding destination for client address.

* Changing to async produces a timeout error instead of getting stuck in an infinite loop.

* Close client during `wait_until_finish`; rm async.

* Supporting side-inputs for ParDo.

* Revert "Close client during `wait_until_finish`; rm async."

This reverts commit 09365f6.

* Revert "Changing to async produces a timeout error instead of stuck in infinite loop."

This reverts commit 676d752.

* Adding -dask tox targets to the Gradle build.

* wip - added print stmt.

* wip - prove side inputs is set.

* wip - prove side inputs is set in Pardo.

* wip - rm asserts, add print

* wip - adding named inputs...

* Experiments: non-named side inputs + del `None` in named inputs.

* None --> 'None'

* No default side input.

* Pass along args + kwargs.

* Applied yapf to dask sources.

* Dask sources passing pylint.

* Added dask extra to docs gen tox env.

* Applied yapf from tox.

* Include dask in mypy checks.

* Upgrading mypy support to Python 3.8 since Python 3.7 support is deprecated in Dask.

* Manually installing an old version of Dask from before Python 3.7 support was dropped.

* fix lint: line too long.

* Fixed type errors with DaskRunnerResult. Disabled mypy type checking in dask.

* Fix pytype errors (in transform_evaluator).

* Ran isort.

* Ran yapf again.

* Fix imports (one per line)

* isort -- alphabetical.

* Added feature to CHANGES.md.

* Ran yapf via tox on a Linux machine.

* Change an import to pass CI.

* Skip isort error; needed to get CI to pass.

* Skip-test logic may fare better with isort.

* (Maybe) the last isort fix.

* Tested pipeline options (added one fix).

* Improve formatting of test.

* Self-review: removing side inputs.

In addition, adding a more helpful property (`transform`) to the base DaskBagOp; a sketch appears after these notes.

* add dask to coverage suite in tox.

* Capture ValueError in assert.

* Change timeout value to 600 seconds.

* Ignoring broken test.

* Update CHANGES.md

* Using reflection to test the Dask client constructor.

* Better method of inspecting the constructor parameters (thanks @TomAugspurger!); see the inspect sketch after these notes.

Co-authored-by: Pablo E <pabloem@apache.org>
Co-authored-by: Pablo <pabloem@users.noreply.github.com>
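For context, a minimal sketch of the shape of that groupby fix, assuming GroupByKey is lowered onto dask.bag (the names here are illustrative, not the exact runner code):

import dask.bag as db

# A bag of (key, value) pairs, as produced upstream of a GroupByKey.
kvs = db.from_sequence([('0', None), ('1', 1), ('1', 2)])

# Bag.groupby yields (key, [(key, value), ...]) pairs, so a follow-up map
# strips the redundant keys out of each grouped list.
grouped = kvs.groupby(lambda kv: kv[0]).map(
    lambda item: (item[0], [v for _, v in item[1]]))

print(grouped.compute())  # e.g. [('0', [None]), ('1', [1, 2])]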
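And a rough sketch of the DaskBagOp base class with the new `transform` property, simplified from transform_evaluator.py (the actual module may differ in detail):

import abc
import dataclasses
import typing as t

import dask.bag as db

from apache_beam.pipeline import AppliedPTransform


@dataclasses.dataclass
class DaskBagOp(abc.ABC):
  applied: AppliedPTransform

  @property
  def transform(self):
    """The PTransform underlying this applied node."""
    return self.applied.transform

  @abc.abstractmethod
  def apply(self, input_bag: t.Optional[db.Bag]) -> db.Bag:
    pass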
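Finally, the constructor-inspection approach mirrors the test file below: the standard-library inspect module reads the parameter names straight off dask.distributed.Client, so each option destination can be checked against them.

import inspect

import dask.distributed as ddist

# Parameter names accepted by the Client constructor; every DaskOptions
# destination (address, timeout, name, ...) should appear in this list.
client_args = list(inspect.signature(ddist.Client).parameters)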
3 people authored and ruslan-ikhsan committed Nov 11, 2022
1 parent bc13c81 commit 2849da6
Showing 10 changed files with 563 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -64,6 +64,7 @@
## Highlights

* Python 3.10 support in Apache Beam ([#21458](https://github.com/apache/beam/issues/21458)).
* An initial implementation of a runner that allows us to run Beam pipelines on Dask. Try it out and give us feedback! (Python) ([#18962](https://github.com/apache/beam/issues/18962)).


## I/Os
@@ -81,6 +82,7 @@
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Dataframe wrapper added in Go SDK via Cross-Language (with automatic expansion service). (Go) ([#23384](https://github.com/apache/beam/issues/23384)).
* Name all Java threads to aid in debugging ([#23049](https://github.com/apache/beam/issues/23049)).
* An initial implementation of a runner that allows us to run Beam pipelines on Dask. (Python) ([#18962](https://github.com/apache/beam/issues/18962)).

## Breaking Changes

16 changes: 16 additions & 0 deletions sdks/python/apache_beam/runners/dask/__init__.py
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
182 changes: 182 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.
The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
  @staticmethod
  def _parse_timeout(candidate):
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.distributed.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')


@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    try:
      if duration is not None:
        # Convert milliseconds to seconds
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')


class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)
        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    pipeline.replace_all(dask_overrides())

    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)
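
For orientation, a minimal sketch of how this runner might be invoked from user code. The scheduler address below is a placeholder; per DaskOptions above, omitting --dask_client_address should fall back to a local cluster.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# 'localhost:8786' is a hypothetical scheduler address.
options = PipelineOptions(['--dask_client_address', 'localhost:8786'])

with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * 2)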
94 changes: 94 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import unittest

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
  from apache_beam.runners.dask.dask_runner import DaskOptions
  from apache_beam.runners.dask.dask_runner import DaskRunner
  import dask
  import dask.distributed as ddist
except (ImportError, ModuleNotFoundError):
  raise unittest.SkipTest('Dask must be installed to run tests.')


class DaskOptionsTest(unittest.TestCase):
  def test_parses_connection_timeout__defaults_to_none(self):
    default_options = PipelineOptions([])
    default_dask_options = default_options.view_as(DaskOptions)
    self.assertEqual(None, default_dask_options.timeout)

  def test_parses_connection_timeout__parses_int(self):
    conn_options = PipelineOptions('--dask_connection_timeout 12'.split())
    dask_conn_options = conn_options.view_as(DaskOptions)
    self.assertEqual(12, dask_conn_options.timeout)

  def test_parses_connection_timeout__handles_bad_input(self):
    err_options = PipelineOptions('--dask_connection_timeout foo'.split())
    dask_err_options = err_options.view_as(DaskOptions)
    self.assertEqual(dask.config.no_default, dask_err_options.timeout)

  def test_parser_destinations__agree_with_dask_client(self):
    options = PipelineOptions(
        '--dask_client_address localhost:8080 --dask_connection_timeout 600 '
        '--dask_scheduler_file foobar.cfg --dask_client_name charlie '
        '--dask_connection_limit 1024'.split())
    dask_options = options.view_as(DaskOptions)

    # Get the argument names for the constructor.
    client_args = list(inspect.signature(ddist.Client).parameters)

    for opt_name in dask_options.get_all_options(drop_default=True).keys():
      with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
        self.assertIn(opt_name, client_args)


class DaskRunnerRunPipelineTest(unittest.TestCase):
  """Test class used to introspect the dask runner via a debugger."""
  def setUp(self) -> None:
    self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())

  def test_create(self):
    with self.pipeline as p:
      pcoll = p | beam.Create([1])
      assert_that(pcoll, equal_to([1]))

  def test_create_and_map(self):
    def double(x):
      return x * 2

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double)
      assert_that(pcoll, equal_to([2]))

  def test_create_map_and_groupby(self):
    def double(x):
      return x * 2, x

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
      assert_that(pcoll, equal_to([(2, [1])]))


if __name__ == '__main__':
  unittest.main()