From b5222218360f5639790e714cdc025f0ede2e21ef Mon Sep 17 00:00:00 2001 From: Maria Garcia Herrero Date: Thu, 5 Jan 2017 19:02:58 -0800 Subject: [PATCH 1/3] Add StaticValueProvider class for FileBasedSource I/O Transforms --- sdks/python/apache_beam/io/avroio_test.py | 9 +++- sdks/python/apache_beam/io/filebasedsource.py | 42 +++++++++++----- .../apache_beam/io/filebasedsource_test.py | 8 ++- sdks/python/apache_beam/io/textio_test.py | 5 +- .../apache_beam/transforms/display_test.py | 28 +++++++++++ .../apache_beam/utils/pipeline_options.py | 14 ++++++ .../utils/pipeline_options_test.py | 22 ++++++++ .../apache_beam/utils/value_provider.py | 50 +++++++++++++++++++ 8 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 sdks/python/apache_beam/utils/value_provider.py diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index aed468dc67b1..91abe62f0683 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -29,6 +29,8 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to +from apache_beam.utils.value_provider import StaticValueProvider + # Importing following private class for testing purposes. from apache_beam.io.avroio import _AvroSource as AvroSource @@ -164,7 +166,9 @@ def test_source_display_data(self): # No extra avro parameters for AvroSource. expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', file_name)] + DisplayDataItemMatcher('file_pattern', + str(StaticValueProvider(str, file_name)))] + hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_read_display_data(self): @@ -175,7 +179,8 @@ def test_read_display_data(self): # No extra avro parameters for AvroSource. expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', file_name)] + DisplayDataItemMatcher('file_pattern', + str(StaticValueProvider(str, file_name)))] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_sink_display_data(self): diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 1bfde258f0b2..181675795846 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -36,6 +36,8 @@ from apache_beam.io import iobase from apache_beam.io import range_trackers from apache_beam.transforms.display import DisplayDataItem +from apache_beam.utils.value_provider import ValueProvider +from apache_beam.utils.value_provider import StaticValueProvider MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25 @@ -55,7 +57,8 @@ def __init__(self, """Initializes ``FileBasedSource``. Args: - file_pattern: the file glob to read. + file_pattern: the file glob to read or a ValueProvider (place holder to + inject a runtime value). min_bundle_size: minimum size of bundles that should be generated when performing initial splitting on this source. compression_type: compression type to use @@ -77,12 +80,16 @@ def __init__(self, ValueError: when compression and splittable files are specified. IOError: when the file pattern specified yields an empty result. """ - if not isinstance(file_pattern, basestring): - raise TypeError( - '%s: file_pattern must be a string; got %r instead' % - (self.__class__.__name__, file_pattern)) + if (not isinstance(file_pattern, basestring) + or isinstance(file_pattern, ValueProvider)): + raise TypeError('%s: file_pattern must be of type string' + ' or ValueProvider; got %r instead' + % (self.__class__.__name__, file_pattern)) self._pattern = file_pattern + if isinstance(file_pattern, basestring): + file_pattern = StaticValueProvider(str, file_pattern) + self._pattern = file_pattern self._concat_source = None self._min_bundle_size = min_bundle_size if not fileio.CompressionTypes.is_valid_compression_type(compression_type): @@ -99,17 +106,20 @@ def __init__(self, self._validate() def display_data(self): - return {'file_pattern': DisplayDataItem(self._pattern, + return {'file_pattern': DisplayDataItem(str(self._pattern), label="File Pattern"), 'compression': DisplayDataItem(str(self._compression_type), label='Compression Type')} def _get_concat_source(self): if self._concat_source is None: + if not self._pattern.is_accessible(): + raise RuntimeError('value not accessible') + single_file_sources = [] - file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)] - sizes = FileBasedSource._estimate_sizes_of_files(file_names, - self._pattern) + pattern = self._pattern.get() + file_names = [f for f in fileio.ChannelFactory.glob(pattern)] + sizes = FileBasedSource._estimate_sizes_of_files(file_names, pattern) # We create a reference for FileBasedSource that will be serialized along # with each _SingleFileSource. To prevent this FileBasedSource from having @@ -176,10 +186,14 @@ def _estimate_sizes_of_files(file_names, pattern=None): def _validate(self): """Validate if there are actual files in the specified glob pattern """ + if not self._pattern.is_accessible(): + raise RuntimeError('value not accessible') + pattern = self._pattern.get() + # Limit the responses as we only want to check if something exists - if len(fileio.ChannelFactory.glob(self._pattern, limit=1)) <= 0: + if len(fileio.ChannelFactory.glob(pattern, limit=1)) <= 0: raise IOError( - 'No files found based on the file pattern %s' % self._pattern) + 'No files found based on the file pattern %s' % pattern) def split( self, desired_bundle_size=None, start_position=None, stop_position=None): @@ -189,7 +203,11 @@ def split( stop_position=stop_position) def estimate_size(self): - file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)] + if not self._pattern.is_accessible(): + raise RuntimeError('value not accessible') + pattern = self._pattern.get() + file_names = [f for f in fileio.ChannelFactory.glob(pattern)] + # We're reading very few files so we can pass names file names to # _estimate_sizes_of_files without pattern as otherwise we'll try to do # optimization based on the pattern and might end up reading much more diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index 8f12627330a7..077663041249 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -42,6 +42,7 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to +from apache_beam.utils.value_provider import StaticValueProvider class LineSource(FileBasedSource): @@ -253,8 +254,10 @@ def test_single_file_display_data(self): fbs = LineSource(file_name) dd = DisplayData.create_from(fbs) expected_items = [ - DisplayDataItemMatcher('file_pattern', file_name), + DisplayDataItemMatcher('file_pattern', + str(StaticValueProvider(str, file_name))), DisplayDataItemMatcher('compression', 'auto')] + print dd.items hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) @@ -586,7 +589,8 @@ def test_source_creation_display_data(self): dd = DisplayData.create_from(fbs) expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', file_name)] + DisplayDataItemMatcher('file_pattern', + str(StaticValueProvider(str, file_name)))] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index 877e1901d9f0..3091a2fb5686 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -49,6 +49,8 @@ from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to +from apache_beam.utils.value_provider import StaticValueProvider + class TextSourceTest(unittest.TestCase): @@ -297,7 +299,8 @@ def test_read_display_data(self): dd = DisplayData.create_from(read) expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', 'prefix'), + DisplayDataItemMatcher('file_pattern', + str(StaticValueProvider(str, 'prefix'))), DisplayDataItemMatcher('strip_newline', True)] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) diff --git a/sdks/python/apache_beam/transforms/display_test.py b/sdks/python/apache_beam/transforms/display_test.py index 848746c96837..c4681c313a0e 100644 --- a/sdks/python/apache_beam/transforms/display_test.py +++ b/sdks/python/apache_beam/transforms/display_test.py @@ -30,6 +30,7 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display import DisplayDataItem from apache_beam.utils.pipeline_options import PipelineOptions +from apache_beam.utils.pipeline_options import static_value_provider_of class DisplayDataItemMatcher(BaseMatcher): @@ -114,6 +115,33 @@ def display_data(self): with self.assertRaises(ValueError): DisplayData.create_from_options(MyDisplayComponent()) + def test_vp_display_data(self): + class TestOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--int_flag', + type=static_value_provider_of(int), + help='int_flag description') + parser.add_argument( + '--str_flag', + type=static_value_provider_of(str), + help='str_flag description') + options = TestOptions(['--int_flag', '1', '--str_flag', '/dev/null']) + # TODO: Make flags be capable of having + # the same name for vp and non-vp values. + items = DisplayData.create_from_options(options).items + expected_items = [ + DisplayDataItemMatcher( + 'int_flag', + 'StaticValueProvider(type=int, value=1)'), + DisplayDataItemMatcher( + 'str_flag', + 'StaticValueProvider(type=str, value=\'/dev/null\')' + ) + ] + hc.assert_that(items, hc.contains_inanyorder(*expected_items)) + def test_create_list_display_data(self): flags = ['--extra_package', 'package1', '--extra_package', 'package2'] pipeline_options = PipelineOptions(flags=flags) diff --git a/sdks/python/apache_beam/utils/pipeline_options.py b/sdks/python/apache_beam/utils/pipeline_options.py index 9f57ee7ea1e2..b763fc0047f3 100644 --- a/sdks/python/apache_beam/utils/pipeline_options.py +++ b/sdks/python/apache_beam/utils/pipeline_options.py @@ -20,6 +20,20 @@ import argparse from apache_beam.transforms.display import HasDisplayData +from apache_beam.utils.value_provider import StaticValueProvider + + +def static_value_provider_of(type): + """"Helper function to plug a ValueProvider into argparse. + + Args: + type: the type of the value. Since the type param of argparse's + add_argument will always be ValueProvider, we need to + preserve the type of the actual value. + """ + def _f(value): + return StaticValueProvider(type, value) + return _f class PipelineOptions(HasDisplayData): diff --git a/sdks/python/apache_beam/utils/pipeline_options_test.py b/sdks/python/apache_beam/utils/pipeline_options_test.py index 054b6a5e4c89..cd7c9ac7fd47 100644 --- a/sdks/python/apache_beam/utils/pipeline_options_test.py +++ b/sdks/python/apache_beam/utils/pipeline_options_test.py @@ -24,6 +24,8 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.utils.pipeline_options import PipelineOptions +from apache_beam.utils.pipeline_options import static_value_provider_of +from apache_beam.utils.value_provider import StaticValueProvider class PipelineOptionsTest(unittest.TestCase): @@ -170,6 +172,26 @@ def test_template_location(self): options = PipelineOptions(flags=['']) self.assertEqual(options.get_all_options()['template_location'], None) + def test_static_value_provider_of(self): + class TestOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument( + '--int_flag', + type=static_value_provider_of(int), + help='--int_flag description') + parser.add_argument( + '--str_flag', + type=static_value_provider_of(str), + help='--str_flag descriptions') + # TODO: Make flags be capable of having + # the same name for vp and non-vp values. + options = TestOptions(['--int_flag', '1', '--str_flag', '/dev/null']) + assert isinstance(options.int_flag, StaticValueProvider) + assert isinstance(options.str_flag, StaticValueProvider) + assert options.int_flag.get() == 1 + assert options.str_flag.get() == '/dev/null' + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/utils/value_provider.py b/sdks/python/apache_beam/utils/value_provider.py new file mode 100644 index 000000000000..140dc8e12b1f --- /dev/null +++ b/sdks/python/apache_beam/utils/value_provider.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A ValueProvider class to implement templates with both hard-coded +and dynamically provided values. +""" + + +class ValueProvider(object): + def is_accessible(self): + raise NotImplementedError( + 'ValueProvider.is_accessible implemented in derived classes' + ) + + def get(self): + raise NotImplementedError( + 'ValueProvider.get implemented in derived classes' + ) + + +class StaticValueProvider(object): + def __init__(self, value_class, value): + self.value_class = value_class + self.data = value_class(value) + self.accessible = True + + def is_accessible(self): + return self.accessible + + def get(self): + return self.data + + def __str__(self): + return '%s(type=%s, value=%s)' % (self.__class__.__name__, + self.value_class.__name__, + repr(self.data)) From 62b0e26169e531eb4dbe84ad5121a0679ad5d5e7 Mon Sep 17 00:00:00 2001 From: Maria Garcia Herrero Date: Sun, 8 Jan 2017 17:32:59 -0800 Subject: [PATCH 2/3] Improve tests --- sdks/python/apache_beam/io/avroio_test.py | 12 ++++++------ sdks/python/apache_beam/io/filebasedsource_test.py | 12 ++++++------ sdks/python/apache_beam/io/textio_test.py | 7 +++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 91abe62f0683..05c3b0a6226d 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -29,7 +29,6 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to -from apache_beam.utils.value_provider import StaticValueProvider # Importing following private class for testing purposes. @@ -166,9 +165,9 @@ def test_source_display_data(self): # No extra avro parameters for AvroSource. expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', - str(StaticValueProvider(str, file_name)))] - + DisplayDataItemMatcher( + 'file_pattern', + 'StaticValueProvider(type=str, value=\'%s\')' % file_name)] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_read_display_data(self): @@ -179,8 +178,9 @@ def test_read_display_data(self): # No extra avro parameters for AvroSource. expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', - str(StaticValueProvider(str, file_name)))] + DisplayDataItemMatcher( + 'file_pattern', + 'StaticValueProvider(type=str, value=\'%s\')' % file_name)] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_sink_display_data(self): diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index 077663041249..aa3449aff65b 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -42,7 +42,6 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to -from apache_beam.utils.value_provider import StaticValueProvider class LineSource(FileBasedSource): @@ -254,10 +253,10 @@ def test_single_file_display_data(self): fbs = LineSource(file_name) dd = DisplayData.create_from(fbs) expected_items = [ - DisplayDataItemMatcher('file_pattern', - str(StaticValueProvider(str, file_name))), + DisplayDataItemMatcher( + 'file_pattern', + 'StaticValueProvider(type=str, value=\'%s\')' % file_name), DisplayDataItemMatcher('compression', 'auto')] - print dd.items hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) @@ -589,8 +588,9 @@ def test_source_creation_display_data(self): dd = DisplayData.create_from(fbs) expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', - str(StaticValueProvider(str, file_name)))] + DisplayDataItemMatcher( + 'file_pattern', + 'StaticValueProvider(type=str, value=\'%s\')' % file_name)] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index 3091a2fb5686..9965ab3e784d 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -49,8 +49,6 @@ from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to -from apache_beam.utils.value_provider import StaticValueProvider - class TextSourceTest(unittest.TestCase): @@ -299,8 +297,9 @@ def test_read_display_data(self): dd = DisplayData.create_from(read) expected_items = [ DisplayDataItemMatcher('compression', 'auto'), - DisplayDataItemMatcher('file_pattern', - str(StaticValueProvider(str, 'prefix'))), + DisplayDataItemMatcher( + 'file_pattern', + 'StaticValueProvider(type=str, value=\'prefix\')'), DisplayDataItemMatcher('strip_newline', True)] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) From 71b0a9dbc89e0cc059816d0befe8e7a7fd02607f Mon Sep 17 00:00:00 2001 From: Maria Garcia Herrero Date: Tue, 17 Jan 2017 22:57:22 -0800 Subject: [PATCH 3/3] Add RuntimeValueProvider class --- sdks/python/apache_beam/io/filebasedsource.py | 6 +- .../apache_beam/transforms/display_test.py | 19 +++--- .../apache_beam/utils/pipeline_options.py | 64 +++++++++++++++++-- .../utils/pipeline_options_test.py | 58 ++++++++++++----- .../apache_beam/utils/value_provider.py | 61 ++++++++++++++++-- 5 files changed, 165 insertions(+), 43 deletions(-) diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 181675795846..d437aec17144 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -114,7 +114,7 @@ def display_data(self): def _get_concat_source(self): if self._concat_source is None: if not self._pattern.is_accessible(): - raise RuntimeError('value not accessible') + raise RuntimeError('%s not accessible' % self._pattern) single_file_sources = [] pattern = self._pattern.get() @@ -187,7 +187,7 @@ def _validate(self): """Validate if there are actual files in the specified glob pattern """ if not self._pattern.is_accessible(): - raise RuntimeError('value not accessible') + raise RuntimeError('%s not accessible' % self._pattern) pattern = self._pattern.get() # Limit the responses as we only want to check if something exists @@ -204,7 +204,7 @@ def split( def estimate_size(self): if not self._pattern.is_accessible(): - raise RuntimeError('value not accessible') + raise RuntimeError('%s not accessible' % self._pattern) pattern = self._pattern.get() file_names = [f for f in fileio.ChannelFactory.glob(pattern)] diff --git a/sdks/python/apache_beam/transforms/display_test.py b/sdks/python/apache_beam/transforms/display_test.py index c4681c313a0e..25bc619acb74 100644 --- a/sdks/python/apache_beam/transforms/display_test.py +++ b/sdks/python/apache_beam/transforms/display_test.py @@ -30,7 +30,6 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display import DisplayDataItem from apache_beam.utils.pipeline_options import PipelineOptions -from apache_beam.utils.pipeline_options import static_value_provider_of class DisplayDataItemMatcher(BaseMatcher): @@ -115,21 +114,20 @@ def display_data(self): with self.assertRaises(ValueError): DisplayData.create_from_options(MyDisplayComponent()) - def test_vp_display_data(self): + def test_value_provider_display_data(self): class TestOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): - parser.add_argument( + parser.add_value_provider_argument( '--int_flag', - type=static_value_provider_of(int), + type=int, help='int_flag description') - parser.add_argument( + parser.add_value_provider_argument( '--str_flag', - type=static_value_provider_of(str), + type=str, + default='hello', help='str_flag description') - options = TestOptions(['--int_flag', '1', '--str_flag', '/dev/null']) - # TODO: Make flags be capable of having - # the same name for vp and non-vp values. + options = TestOptions(['--int_flag', '1']) items = DisplayData.create_from_options(options).items expected_items = [ DisplayDataItemMatcher( @@ -137,7 +135,8 @@ def _add_argparse_args(cls, parser): 'StaticValueProvider(type=int, value=1)'), DisplayDataItemMatcher( 'str_flag', - 'StaticValueProvider(type=str, value=\'/dev/null\')' + 'RuntimeValueProvider(option=str_flag,' + ' type=str, default_value=\'hello\', value=None)' ) ] hc.assert_that(items, hc.contains_inanyorder(*expected_items)) diff --git a/sdks/python/apache_beam/utils/pipeline_options.py b/sdks/python/apache_beam/utils/pipeline_options.py index b763fc0047f3..6eebd3a04ff8 100644 --- a/sdks/python/apache_beam/utils/pipeline_options.py +++ b/sdks/python/apache_beam/utils/pipeline_options.py @@ -21,21 +21,68 @@ from apache_beam.transforms.display import HasDisplayData from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import RuntimeValueProvider +from apache_beam.utils.value_provider import ValueProvider -def static_value_provider_of(type): +def static_value_provider_of(value_type): """"Helper function to plug a ValueProvider into argparse. Args: - type: the type of the value. Since the type param of argparse's + value_type: the type of the value. Since the type param of argparse's add_argument will always be ValueProvider, we need to preserve the type of the actual value. """ def _f(value): - return StaticValueProvider(type, value) + _f.func_name = value_type.__name__ + return StaticValueProvider(value_type, value) return _f +class ValueProviderArgumentParser(argparse.ArgumentParser): + """This class provides an API to add options of ValueProvider type. + + It preserves the functionalities of the parent ArgumentParser. + A template user willing to define parameterizable options will + only need to define a subclass of PipelineOptions in this manner: + + class TemplateUserOptions(PipelineOptions): + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument('--abc', default='start') + parser.add_value_provider_argument('--xyz', default='end') + """ + def add_value_provider_argument(self, *args, **kwargs): + # extract the option name. + # TODO (mariapython): handle multiple positional args like ('--quux',) + option_name = args[0].replace('-', '') + + # reassign the type to make room for StaticValueProvider + value_type = kwargs.get('type') or str + + # use StaticValueProvider as the type of the argument + kwargs['type'] = static_value_provider_of(value_type) + + # reassign default to value_default to make room for using + # RuntimeValueProvider as the default for add_argument + default_value = kwargs.get('default') + + # use RuntimeValueProvider() as the default + kwargs['default'] = RuntimeValueProvider( + pipeline_options_subclass=(self.pipeline_options_subclass + or PipelineOptions), + option_name=option_name, + value_type=value_type, + default_value=default_value, + optionsid='id' + ) + kwargs['nargs'] = '?' # make positional arguments optionally templated + + # we still want add_argument to do most of the work + self.add_argument(*args, **kwargs) + + class PipelineOptions(HasDisplayData): """Pipeline options class used as container for command line options. @@ -81,11 +128,13 @@ def __init__(self, flags=None, **kwargs): """ self._flags = flags self._all_options = kwargs - parser = argparse.ArgumentParser() + parser = ValueProviderArgumentParser() + for cls in type(self).mro(): if cls == PipelineOptions: break elif '_add_argparse_args' in cls.__dict__: + parser.pipeline_options_subclass = cls cls._add_argparse_args(parser) # The _visible_options attribute will contain only those options from the # flags (i.e., command line) that can be recognized. The _all_options @@ -130,8 +179,9 @@ def get_all_options(self, drop_default=False): Returns: Dictionary of all args and values. """ - parser = argparse.ArgumentParser() + parser = ValueProviderArgumentParser() for cls in PipelineOptions.__subclasses__(): + parser.pipeline_options_subclass = cls cls._add_argparse_args(parser) # pylint: disable=protected-access known_args, _ = parser.parse_known_args(self._flags) result = vars(known_args) @@ -140,7 +190,9 @@ def get_all_options(self, drop_default=False): for k in result.keys(): if k in self._all_options: result[k] = self._all_options[k] - if drop_default and parser.get_default(k) == result[k]: + if (drop_default and + parser.get_default(k) == result[k] and + not isinstance(parser.get_default(k), ValueProvider)): del result[k] return result diff --git a/sdks/python/apache_beam/utils/pipeline_options_test.py b/sdks/python/apache_beam/utils/pipeline_options_test.py index cd7c9ac7fd47..2a397cccae39 100644 --- a/sdks/python/apache_beam/utils/pipeline_options_test.py +++ b/sdks/python/apache_beam/utils/pipeline_options_test.py @@ -24,8 +24,8 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.utils.pipeline_options import PipelineOptions -from apache_beam.utils.pipeline_options import static_value_provider_of from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import RuntimeValueProvider class PipelineOptionsTest(unittest.TestCase): @@ -172,25 +172,49 @@ def test_template_location(self): options = PipelineOptions(flags=['']) self.assertEqual(options.get_all_options()['template_location'], None) - def test_static_value_provider_of(self): - class TestOptions(PipelineOptions): + def test_value_provider_options(self): + class UserOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): + parser.add_value_provider_argument( + '--vp_arg', + help='This flag is a value provider') + + parser.add_value_provider_argument( + '--vp_arg2', + default=1, + type=int) + parser.add_argument( - '--int_flag', - type=static_value_provider_of(int), - help='--int_flag description') - parser.add_argument( - '--str_flag', - type=static_value_provider_of(str), - help='--str_flag descriptions') - # TODO: Make flags be capable of having - # the same name for vp and non-vp values. - options = TestOptions(['--int_flag', '1', '--str_flag', '/dev/null']) - assert isinstance(options.int_flag, StaticValueProvider) - assert isinstance(options.str_flag, StaticValueProvider) - assert options.int_flag.get() == 1 - assert options.str_flag.get() == '/dev/null' + '--non_vp_arg', + default=1, + type=int + ) + + # Provide values: if not provided, the option becomes of the type runtime vp + options = UserOptions(['--vp_arg', 'hello']) + self.assertIsInstance(options.vp_arg, StaticValueProvider) + self.assertIsInstance(options.vp_arg2, RuntimeValueProvider) + self.assertIsInstance(options.non_vp_arg, int) + + # Values can be overwritten + options = UserOptions(vp_arg=5, + vp_arg2=StaticValueProvider(value_type=str, + value='bye'), + non_vp_arg=RuntimeValueProvider( + pipeline_options_subclass=UserOptions, + option_name='foo', + value_type=int, + default_value=10, + optionsid='id')) + self.assertEqual(options.vp_arg, 5) + self.assertTrue(options.vp_arg2.is_accessible(), + '%s is not accessible' % options.vp_arg2) + self.assertEqual(options.vp_arg2.get(), 'bye') + self.assertEqual(options.non_vp_arg.is_accessible(), False) + + with self.assertRaises(RuntimeError): + options.non_vp_arg.get() if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/utils/value_provider.py b/sdks/python/apache_beam/utils/value_provider.py index 140dc8e12b1f..fa9f25e24066 100644 --- a/sdks/python/apache_beam/utils/value_provider.py +++ b/sdks/python/apache_beam/utils/value_provider.py @@ -15,7 +15,7 @@ # limitations under the License. # -"""A ValueProvider class to implement templates with both hard-coded +"""A ValueProvider class to implement templates with both statically and dynamically provided values. """ @@ -33,18 +33,65 @@ def get(self): class StaticValueProvider(object): - def __init__(self, value_class, value): - self.value_class = value_class - self.data = value_class(value) - self.accessible = True + def __init__(self, value_type, value): + self.value_type = value_type + self.data = value_type(value) def is_accessible(self): - return self.accessible + return True def get(self): return self.data def __str__(self): return '%s(type=%s, value=%s)' % (self.__class__.__name__, - self.value_class.__name__, + self.value_type.__name__, repr(self.data)) + + +class RuntimeValueProvider(ValueProvider): + options_map = {} + + def __init__(self, pipeline_options_subclass, option_name, + value_type, default_value, optionsid): + self.pipeline_options_subclass = pipeline_options_subclass + self.option_name = option_name + self.default_value = default_value + self.value_type = value_type + self.optionsid = 'id' # TODO (mariapython): remove hard-coded value + # self.optionsid = optionsid + self.data = None + + def is_accessible(self): + options = RuntimeValueProvider.options_map.get(self.optionsid) + return options is not None + + def get(self): + options = RuntimeValueProvider.options_map.get(self.optionsid) + if options is None: + # raise RuntimeError('Not called from a runtime context') + raise RuntimeError('%s.get() not called from a runtime context' %self) + result = ( + options.view_as(self.pipeline_options_subclass) + ._visible_options + .__dict__ + .get(self.option_name) + ) + value = ( + result.get() + if isinstance(result, StaticValueProvider) + else self.default_value + ) + return value + + def set_runtime_options(self, options): + RuntimeValueProvider.options_map['id'] = options + + def __str__(self): + return '%s(option=%s, type=%s, default_value=%s, value=%s)' % ( + self.__class__.__name__, + self.option_name, + self.value_type.__name__, + repr(self.default_value), + repr(self.data) + )