[BEAM-802] Python templates #2545

Closed - wants to merge 6 commits
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/error.py
@@ -34,6 +34,10 @@ class RunnerError(BeamError):
"""An error related to a Runner object (e.g. cannot find a runner to run)."""


class RuntimeValueProviderError(RuntimeError):
"""An error related to a ValueProvider object raised during runtime."""


class SideInputError(BeamError):
"""An error related to a side input to a parallel Do operation."""

32 changes: 18 additions & 14 deletions sdks/python/apache_beam/examples/wordcount.py
@@ -19,7 +19,6 @@

from __future__ import absolute_import

import argparse
import logging
import re

@@ -67,24 +66,29 @@ def process(self, element):

def run(argv=None):
"""Main entry point; defines and runs the wordcount pipeline."""
parser = argparse.ArgumentParser()
parser.add_argument('--input',
dest='input',
default='gs://dataflow-samples/shakespeare/kinglear.txt',
help='Input file to process.')
parser.add_argument('--output',
dest='output',
required=True,
help='Output file to write results to.')
known_args, pipeline_args = parser.parse_known_args(argv)
class WordcountOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
'--input',
dest='input',
default='gs://dataflow-samples/shakespeare/kinglear.txt',
help='Input file to process.')
parser.add_value_provider_argument(
'--output',
dest='output',
required=True,
help='Output file to write results to.')
pipeline_options = PipelineOptions(argv)
wordcount_options = pipeline_options.view_as(WordcountOptions)

# We use the save_main_session option because one or more DoFn's in this
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = True
p = beam.Pipeline(options=pipeline_options)

# Read the text file[pattern] into a PCollection.
lines = p | 'read' >> ReadFromText(known_args.input)
lines = p | 'read' >> ReadFromText(wordcount_options.input)

# Count the occurrences of each word.
counts = (lines
@@ -99,7 +103,7 @@ def run(argv=None):

# Write the output using a "Write" transform that has side effects.
# pylint: disable=expression-not-assigned
output | 'write' >> WriteToText(known_args.output)
output | 'write' >> WriteToText(wordcount_options.output)

# Actually run the pipeline (all operations above are deferred).
result = p.run()
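As a usage note rather than part of this PR: after the change above, wordcount_options.input is a ValueProvider, not a string, and ReadFromText defers resolving it. User code that needs the concrete value inside a transform would call .get() at execution time. A minimal hypothetical sketch (the DoFn below is illustrative only):

import apache_beam as beam


class FilterByKeywordFn(beam.DoFn):
  """Hypothetical DoFn showing runtime access to a ValueProvider option."""

  def __init__(self, keyword_vp):
    # keyword_vp would come from an option declared with
    # parser.add_value_provider_argument('--keyword', ...).
    self.keyword_vp = keyword_vp

  def process(self, element):
    # .get() is called at execution time, when the runtime value has been
    # resolved; calling it during pipeline construction would raise an
    # error for a RuntimeValueProvider.
    if self.keyword_vp.get() in element:
      yield element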
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/internal/gcp/json_value.py
@@ -25,6 +25,8 @@
extra_types = None
# pylint: enable=wrong-import-order, wrong-import-position

from apache_beam.utils.value_provider import ValueProvider


_MAXINT64 = (1 << 63) - 1
_MININT64 = - (1 << 63)
@@ -104,6 +106,10 @@ def to_json_value(obj, with_type=False):
raise TypeError('Can not encode {} as a 64-bit integer'.format(obj))
elif isinstance(obj, float):
return extra_types.JsonValue(double_value=obj)
elif isinstance(obj, ValueProvider):
Contributor:
Update json_values_tests to have this case?

if obj.is_accessible():
return to_json_value(obj.get())
return extra_types.JsonValue(is_null=True)
else:
raise TypeError('Cannot convert %s to a JSON value.' % repr(obj))

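Following the review comment above about covering this branch in the json_value tests, here is a minimal sketch of what such a test could look like. The test class name is made up, the ValueProvider constructor signatures are taken from filebasedsource_test.py below, and the sketch assumes a RuntimeValueProvider is not accessible until runtime options are resolved (and that the GCP extra dependencies providing extra_types are installed):

# Sketch only; not the test added in this PR.
import unittest

from apache_beam.internal.gcp.json_value import to_json_value
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import RuntimeValueProvider


class ValueProviderJsonValueTest(unittest.TestCase):

  def test_static_value_provider_to_json(self):
    # Accessible provider: encoded like its underlying value.
    svp = StaticValueProvider(value_type=str, value='abc')
    self.assertEqual(to_json_value('abc'), to_json_value(svp))

  def test_runtime_value_provider_to_json(self):
    # Not accessible at pipeline construction time: encoded as JSON null.
    rvp = RuntimeValueProvider(option_name='arg',
                               value_type=str,
                               default_value='123')
    self.assertTrue(to_json_value(rvp).is_null)


if __name__ == '__main__':
  unittest.main()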
53 changes: 39 additions & 14 deletions sdks/python/apache_beam/io/filebasedsource.py
@@ -32,6 +32,9 @@
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems_util import get_filesystem
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible

MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25

@@ -51,7 +54,8 @@ def __init__(self,
"""Initializes ``FileBasedSource``.

Args:
file_pattern: the file glob to read.
file_pattern: the file glob to read, a string or a ValueProvider
(placeholder to inject a runtime value).
min_bundle_size: minimum size of bundles that should be generated when
performing initial splitting on this source.
compression_type: compression type to use
@@ -69,17 +73,24 @@
creation time.
Raises:
TypeError: when compression_type is not valid or if file_pattern is not a
string.
string or a ValueProvider.
ValueError: when compression and splittable files are specified.
IOError: when the file pattern specified yields an empty result.
"""
if not isinstance(file_pattern, basestring):
raise TypeError(
'%s: file_pattern must be a string; got %r instead' %
(self.__class__.__name__, file_pattern))

if not isinstance(file_pattern, (basestring, ValueProvider)):
raise TypeError('%s: file_pattern must be of type string'
' or ValueProvider; got %r instead'
% (self.__class__.__name__, file_pattern))

if isinstance(file_pattern, basestring):
file_pattern = StaticValueProvider(str, file_pattern)
self._pattern = file_pattern
self._file_system = get_filesystem(file_pattern)
if file_pattern.is_accessible():
self._file_system = get_filesystem(file_pattern.get())
else:
self._file_system = None

self._concat_source = None
self._min_bundle_size = min_bundle_size
if not CompressionTypes.is_valid_compression_type(compression_type):
@@ -92,19 +103,24 @@ def __init__(self,
else:
# We can't split compressed files efficiently so turn off splitting.
self._splittable = False
if validate:
if validate and file_pattern.is_accessible():
self._validate()

def display_data(self):
return {'file_pattern': DisplayDataItem(self._pattern,
return {'file_pattern': DisplayDataItem(str(self._pattern),
label="File Pattern"),
'compression': DisplayDataItem(str(self._compression_type),
label='Compression Type')}

@check_accessible(['_pattern'])
def _get_concat_source(self):
if self._concat_source is None:
pattern = self._pattern.get()

single_file_sources = []
match_result = self._file_system.match([self._pattern])[0]
if self._file_system is None:
self._file_system = get_filesystem(pattern)
match_result = self._file_system.match([pattern])[0]
files_metadata = match_result.metadata_list

# We create a reference for FileBasedSource that will be serialized along
@@ -143,14 +159,19 @@ def open_file(self, file_name):
file_name, 'application/octet-stream',
compression_type=self._compression_type)

@check_accessible(['_pattern'])
def _validate(self):
"""Validate if there are actual files in the specified glob pattern
"""
pattern = self._pattern.get()
if self._file_system is None:
self._file_system = get_filesystem(pattern)

# Limit the responses as we only want to check if something exists
match_result = self._file_system.match([self._pattern], limits=[1])[0]
match_result = self._file_system.match([pattern], limits=[1])[0]
if len(match_result.metadata_list) <= 0:
raise IOError(
'No files found based on the file pattern %s' % self._pattern)
'No files found based on the file pattern %s' % pattern)

def split(
self, desired_bundle_size=None, start_position=None, stop_position=None):
@@ -159,8 +180,12 @@
start_position=start_position,
stop_position=stop_position)

@check_accessible(['_pattern'])
def estimate_size(self):
match_result = self._file_system.match([self._pattern])[0]
pattern = self._pattern.get()
if self._file_system is None:
self._file_system = get_filesystem(pattern)
match_result = self._file_system.match([pattern])[0]
return sum([f.size_in_bytes for f in match_result.metadata_list])

def read(self, range_tracker):
@@ -185,7 +210,7 @@ def read_records(self, file_name, offset_range_tracker):
defined by a given ``RangeTracker``.

Returns:
a iterator that gives the records read from the given file.
an iterator that gives the records read from the given file.
"""
raise NotImplementedError

23 changes: 23 additions & 0 deletions sdks/python/apache_beam/io/filebasedsource_test.py
@@ -43,6 +43,8 @@
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.util import assert_that
from apache_beam.transforms.util import equal_to
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import RuntimeValueProvider


class LineSource(FileBasedSource):
@@ -221,6 +223,27 @@ def setUp(self):
# environments with limited amount of resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

def test_string_or_value_provider_only(self):
str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
self.assertEqual(str_file_pattern,
FileBasedSource(str_file_pattern)._pattern.value)

static_vp_file_pattern = StaticValueProvider(value_type=str,
value=str_file_pattern)
self.assertEqual(static_vp_file_pattern,
FileBasedSource(static_vp_file_pattern)._pattern)

runtime_vp_file_pattern = RuntimeValueProvider(
option_name='arg',
value_type=str,
default_value=str_file_pattern)
self.assertEqual(runtime_vp_file_pattern,
FileBasedSource(runtime_vp_file_pattern)._pattern)

invalid_file_pattern = 123
with self.assertRaises(TypeError):
FileBasedSource(invalid_file_pattern)

def test_validation_file_exists(self):
file_name, _ = write_data(10)
LineSource(file_name)
54 changes: 40 additions & 14 deletions sdks/python/apache_beam/io/fileio.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sources and sinks."""

from __future__ import absolute_import
@@ -30,6 +31,9 @@
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems_util import get_filesystem
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible

DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'

@@ -148,33 +152,39 @@ def __init__(self,
compression_type=CompressionTypes.AUTO):
"""
Raises:
TypeError: if file path parameters are not a string or if compression_type
is not member of CompressionTypes.
TypeError: if file path parameters are not a string or ValueProvider,
or if compression_type is not member of CompressionTypes.
ValueError: if shard_name_template is not of expected format.
"""
if not isinstance(file_path_prefix, basestring):
raise TypeError('file_path_prefix must be a string; got %r instead' %
file_path_prefix)
if not isinstance(file_name_suffix, basestring):
raise TypeError('file_name_suffix must be a string; got %r instead' %
file_name_suffix)
if not isinstance(file_path_prefix, (basestring, ValueProvider)):
raise TypeError('file_path_prefix must be a string or ValueProvider;'
' got %r instead' % file_path_prefix)
if not isinstance(file_name_suffix, (basestring, ValueProvider)):
raise TypeError('file_name_suffix must be a string or ValueProvider;'
' got %r instead' % file_name_suffix)

if not CompressionTypes.is_valid_compression_type(compression_type):
raise TypeError('compression_type must be CompressionType object but '
'was %s' % type(compression_type))

if shard_name_template is None:
shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
elif shard_name_template == '':
num_shards = 1
if isinstance(file_path_prefix, basestring):
file_path_prefix = StaticValueProvider(str, file_path_prefix)
if isinstance(file_name_suffix, basestring):
file_name_suffix = StaticValueProvider(str, file_name_suffix)
self.file_path_prefix = file_path_prefix
self.file_name_suffix = file_name_suffix
self.num_shards = num_shards
self.coder = coder
self.shard_name_format = self._template_to_format(shard_name_template)
self.compression_type = compression_type
self.mime_type = mime_type
self._file_system = get_filesystem(file_path_prefix)
if file_path_prefix.is_accessible():
self._file_system = get_filesystem(file_path_prefix.get())
else:
self._file_system = None

def display_data(self):
return {'shards':
@@ -188,12 +198,15 @@ def display_data(self):
self.file_name_suffix),
label='File Pattern')}

@check_accessible(['file_path_prefix'])
def open(self, temp_path):
"""Opens ``temp_path``, returning an opaque file handle object.

The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
if self._file_system is None:
self._file_system = get_filesystem(self.file_path_prefix.get())
return self._file_system.create(temp_path, self.mime_type,
self.compression_type)

@@ -220,22 +233,33 @@ def close(self, file_handle):
if file_handle is not None:
file_handle.close()

@check_accessible(['file_path_prefix', 'file_name_suffix'])
Contributor:
@sammcveety random question: I wanted to learn more about the accessibility check in the value providers.

Pros:

  1. The check is done before any user code is executed, so we can minimize side effects.

Cons:

  1. More changes to existing user code, and the end user needs to make sure they are using the correct strings in the decorator.
  2. We could just throw a runtime exception in get(), as that minimizes the concepts the user needs to be aware of when writing the pipeline.

The user can already just ignore the check if they want, so I guess it is already sort of optional.

Member Author:
The Java SDK does not have a concept similar to the check_accessible decorator. The common API across both languages is is_accessible and get, where the second one throws a runtime exception if the underlying value is not accessible (i.e. when called at pipeline construction time).

Using check_accessible is optional; a PTransform author can build a templatable transform without it. Another pro is that it improves readability by clearly marking functions that depend on runtime values.
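
To make the trade-off above concrete, here is a minimal sketch of both styles, using only the value_provider imports that already appear in this diff. The DebugSink class and its attribute are hypothetical; the sketch assumes check_accessible takes a list of attribute names (as in the decorators above) and that get() raises when a runtime value is not yet resolved:

# Illustrative sketch only: DebugSink is not part of this PR; the imports
# and decorator usage mirror the fileio.py changes above.
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible


class DebugSink(object):

  def __init__(self, file_path_prefix):
    # Normalize plain strings so the rest of the class only deals with
    # ValueProvider objects, as FileSink.__init__ does above.
    if isinstance(file_path_prefix, basestring):
      file_path_prefix = StaticValueProvider(str, file_path_prefix)
    self.file_path_prefix = file_path_prefix

  @check_accessible(['file_path_prefix'])
  def initialize_write(self):
    # Style 1: the decorator fails fast before any of the body runs if
    # self.file_path_prefix is not accessible yet.
    return self.file_path_prefix.get() + '-temp'

  def initialize_write_without_decorator(self):
    # Style 2: rely on get() itself, which is expected to raise a runtime
    # error when called at pipeline construction time on an unresolved
    # RuntimeValueProvider (see RuntimeValueProviderError in error.py).
    return self.file_path_prefix.get() + '-temp'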

def initialize_write(self):
tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime(
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
tmp_dir = file_path_prefix + file_name_suffix + time.strftime(
'-temp-%Y-%m-%d_%H-%M-%S')
if self._file_system is None:
self._file_system = get_filesystem(file_path_prefix)
self._file_system.mkdirs(tmp_dir)
return tmp_dir

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def open_writer(self, init_result, uid):
# A proper suffix is needed for AUTO compression detection.
# We also ensure there will be no collisions with uid and a
# (possibly unsharded) file_path_prefix and a (possibly empty)
# file_name_suffix.
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
suffix = (
'.' + os.path.basename(self.file_path_prefix) + self.file_name_suffix)
'.' + os.path.basename(file_path_prefix) + file_name_suffix)
return FileSinkWriter(self, os.path.join(init_result, uid) + suffix)

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def finalize_write(self, init_result, writer_results):
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
writer_results = sorted(writer_results)
num_shards = len(writer_results)
min_threads = min(num_shards, FileSink._MAX_RENAME_THREADS)
@@ -246,8 +270,8 @@ def finalize_write(self, init_result, writer_results):
chunk_size = self._file_system.CHUNK_SIZE
for shard_num, shard in enumerate(writer_results):
final_name = ''.join([
self.file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), self.file_name_suffix
file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), file_name_suffix
])
source_files.append(shard)
destination_files.append(final_name)
Expand All @@ -270,6 +294,8 @@ def _rename_batch(batch):
"""_rename_batch executes batch rename operations."""
source_files, destination_files = batch
exceptions = []
if self._file_system is None:
self._file_system = get_filesystem(file_path_prefix)
try:
self._file_system.rename(source_files, destination_files)
return exceptions