Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions sdks/python/apache_beam/examples/wordcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from __future__ import absolute_import

import argparse
import logging
import re

Expand Down Expand Up @@ -67,24 +66,28 @@ def process(self, element):

def run(argv=None):
"""Main entry point; defines and runs the wordcount pipeline."""
parser = argparse.ArgumentParser()
parser.add_argument('--input',
dest='input',
default='gs://dataflow-samples/shakespeare/kinglear.txt',
help='Input file to process.')
parser.add_argument('--output',
dest='output',
required=True,
help='Output file to write results to.')
known_args, pipeline_args = parser.parse_known_args(argv)
class WordcountOptions(PipelineOptions):
  """Command-line options specific to the wordcount example."""

  @classmethod
  def _add_argparse_args(cls, parser):
    """Registers ``--input`` and ``--output`` as value-provider arguments.

    Args:
      parser: the parser object on which the arguments are registered; it
        must offer ``add_value_provider_argument``.
    """
    input_settings = dict(
        dest='input',
        default='gs://dataflow-samples/shakespeare/kinglear.txt',
        help='Input file to process.')
    output_settings = dict(
        dest='output',
        required=True,
        help='Output file to write results to.')

    # Registration order is kept: input first, then output.
    parser.add_value_provider_argument('--input', **input_settings)
    parser.add_value_provider_argument('--output', **output_settings)
pipeline_options = PipelineOptions(argv)
wordcount_options = pipeline_options.view_as(WordcountOptions)
# We use the save_main_session option because one or more DoFn's in this
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = True
p = beam.Pipeline(options=pipeline_options)

# Read the text file[pattern] into a PCollection.
lines = p | 'read' >> ReadFromText(known_args.input)
lines = p | 'read' >> ReadFromText(wordcount_options.input)

# Count the occurrences of each word.
counts = (lines
Expand All @@ -99,7 +102,7 @@ def run(argv=None):

# Write the output using a "Write" transform that has side effects.
# pylint: disable=expression-not-assigned
output | 'write' >> WriteToText(known_args.output)
output | 'write' >> WriteToText(wordcount_options.output)

# Actually run the pipeline (all operations above are deferred).
result = p.run()
Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/internal/gcp/json_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
extra_types = None
# pylint: enable=wrong-import-order, wrong-import-position

from apache_beam.utils.value_provider import ValueProvider


_MAXINT64 = (1 << 63) - 1
_MININT64 = - (1 << 63)
Expand Down Expand Up @@ -104,6 +106,10 @@ def to_json_value(obj, with_type=False):
raise TypeError('Can not encode {} as a 64-bit integer'.format(obj))
elif isinstance(obj, float):
return extra_types.JsonValue(double_value=obj)
elif isinstance(obj, ValueProvider):
if obj.is_accessible():
return to_json_value(obj.get())
return extra_types.JsonValue(is_null=True)
else:
raise TypeError('Cannot convert %s to a JSON value.' % repr(obj))

Expand Down
46 changes: 32 additions & 14 deletions sdks/python/apache_beam/io/filebasedsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
from apache_beam.io import iobase
from apache_beam.io import range_trackers
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible


MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25

Expand All @@ -53,7 +57,8 @@ def __init__(self,
"""Initializes ``FileBasedSource``.

Args:
file_pattern: the file glob to read.
file_pattern: the file glob to read, given either as a string or as a
ValueProvider (a placeholder into which a runtime value is injected).
min_bundle_size: minimum size of bundles that should be generated when
performing initial splitting on this source.
compression_type: compression type to use
Expand All @@ -71,15 +76,19 @@ def __init__(self,
creation time.
Raises:
TypeError: when compression_type is not valid or if file_pattern is not a
string.
string or a ValueProvider.
ValueError: when compression and splittable files are specified.
IOError: when the file pattern specified yields an empty result.
"""
if not isinstance(file_pattern, basestring):
raise TypeError(
'%s: file_pattern must be a string; got %r instead' %
(self.__class__.__name__, file_pattern))

if (not (isinstance(file_pattern, basestring)
or isinstance(file_pattern, ValueProvider))):
raise TypeError('%s: file_pattern must be of type string'
' or ValueProvider; got %r instead'
% (self.__class__.__name__, file_pattern))

if isinstance(file_pattern, basestring):
file_pattern = StaticValueProvider(str, file_pattern)
self._pattern = file_pattern
self._concat_source = None
self._min_bundle_size = min_bundle_size
Expand All @@ -93,21 +102,24 @@ def __init__(self,
else:
# We can't split compressed files efficiently so turn off splitting.
self._splittable = False
if validate:
if validate and file_pattern.is_accessible():
self._validate()

def display_data(self):
  """Returns the display data items describing this source.

  ``self._pattern`` is stored as a ValueProvider (plain strings are wrapped
  in a StaticValueProvider by ``__init__``), so it is converted with
  ``str()`` before being handed to ``DisplayDataItem``.

  Returns:
    A dict with a 'file_pattern' item and a 'compression' item.
  """
  return {'file_pattern': DisplayDataItem(str(self._pattern),
                                          label="File Pattern"),
          'compression': DisplayDataItem(str(self._compression_type),
                                         label='Compression Type')}

@check_accessible(['_pattern'])
def _get_concat_source(self):
if self._concat_source is None:
pattern = self._pattern.get()

single_file_sources = []
file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)]
file_names = [f for f in fileio.ChannelFactory.glob(pattern)]
sizes = FileBasedSource._estimate_sizes_of_files(file_names,
self._pattern)
pattern)

# We create a reference for FileBasedSource that will be serialized along
# with each _SingleFileSource. To prevent this FileBasedSource from having
Expand Down Expand Up @@ -164,13 +176,16 @@ def _estimate_sizes_of_files(file_names, pattern=None):
file_names)
return [file_sizes[f] for f in file_names]

@check_accessible(['_pattern'])
def _validate(self):
  """Validate if there are actual files in the specified glob pattern.

  Raises:
    IOError: when the resolved file pattern matches no files.
  """
  # self._pattern is a ValueProvider; resolve it to the concrete glob string
  # before globbing and before formatting the error message.
  # NOTE(review): check_accessible presumably rejects the call while the
  # value is not yet available -- confirm against its definition.
  pattern = self._pattern.get()

  # Limit the responses as we only want to check if something exists
  if len(fileio.ChannelFactory.glob(pattern, limit=1)) <= 0:
    raise IOError(
        'No files found based on the file pattern %s' % pattern)

def split(
self, desired_bundle_size=None, start_position=None, stop_position=None):
Expand All @@ -179,8 +194,11 @@ def split(
start_position=start_position,
stop_position=stop_position)

@check_accessible(['_pattern'])
def estimate_size(self):
file_names = [f for f in fileio.ChannelFactory.glob(self._pattern)]
pattern = self._pattern.get()
file_names = [f for f in fileio.ChannelFactory.glob(pattern)]

# We're reading very few files so we can pass names file names to
# _estimate_sizes_of_files without pattern as otherwise we'll try to do
# optimization based on the pattern and might end up reading much more
Expand Down Expand Up @@ -221,7 +239,7 @@ def read_records(self, file_name, offset_range_tracker):
defined by a given ``RangeTracker``.

Returns:
a iterator that gives the records read from the given file.
an iterator that gives the records read from the given file.
"""
raise NotImplementedError

Expand Down
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/io/filebasedsource_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.util import assert_that
from apache_beam.transforms.util import equal_to
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import RuntimeValueProvider


class LineSource(FileBasedSource):
Expand Down Expand Up @@ -221,6 +223,28 @@ def setUp(self):
# environments with limited amount of resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

def test_string_or_value_provider_only(self):
  """Checks which ``file_pattern`` types ``FileBasedSource`` accepts.

  A plain string is expected to be wrapped in a provider exposing the same
  value; static and runtime value providers are expected to be stored as
  given; any other type must raise ``TypeError``.
  """
  # Local import: only part of the module is visible here, so the cleanup
  # below does not rely on a module-level ``import os``.
  import os

  # delete=False keeps the file on disk after the handle is closed, so the
  # path still exists while the sources below are constructed. The handle is
  # closed right away and the file is removed in the ``finally`` clause, so
  # the test no longer leaks a temporary file.
  with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    str_file_pattern = temp_file.name

  try:
    self.assertEqual(str_file_pattern,
                     FileBasedSource(str_file_pattern)._pattern.value)

    static_vp_file_pattern = StaticValueProvider(value_type=str,
                                                 value=str_file_pattern)
    self.assertEqual(static_vp_file_pattern,
                     FileBasedSource(static_vp_file_pattern)._pattern)

    runtime_vp_file_pattern = RuntimeValueProvider(
        pipeline_options_subclass=object,
        option_name='blah',
        value_type=str,
        default_value=str_file_pattern)
    self.assertEqual(runtime_vp_file_pattern,
                     FileBasedSource(runtime_vp_file_pattern)._pattern)

    invalid_file_pattern = 123
    with self.assertRaises(TypeError):
      FileBasedSource(invalid_file_pattern)
  finally:
    os.remove(str_file_pattern)

def test_validation_file_exists(self):
file_name, _ = write_data(10)
LineSource(file_name)
Expand Down Expand Up @@ -587,7 +611,9 @@ def test_source_creation_display_data(self):
dd = DisplayData.create_from(fbs)
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)]
DisplayDataItemMatcher(
'file_pattern',
file_name)]
hc.assert_that(dd.items,
hc.contains_inanyorder(*expected_items))

Expand Down
44 changes: 31 additions & 13 deletions sdks/python/apache_beam/io/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sources and sinks."""

from __future__ import absolute_import
Expand All @@ -31,6 +32,9 @@
from apache_beam.internal import util
from apache_beam.io import iobase
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible

# TODO(sourabhbajaj): Fix the constant values after the new IO factory
# Current constants are copy pasted from gcsio.py till we fix this.
Expand Down Expand Up @@ -544,25 +548,30 @@ def __init__(self,
compression_type=CompressionTypes.AUTO):
"""
Raises:
TypeError: if file path parameters are not a string or if compression_type
is not member of CompressionTypes.
TypeError: if file path parameters are not a string or ValueProvider,
or if compression_type is not member of CompressionTypes.
ValueError: if shard_name_template is not of expected format.
"""
if not isinstance(file_path_prefix, basestring):
raise TypeError('file_path_prefix must be a string; got %r instead' %
file_path_prefix)
if not isinstance(file_name_suffix, basestring):
raise TypeError('file_name_suffix must be a string; got %r instead' %
file_name_suffix)
if not (isinstance(file_path_prefix, basestring)
or isinstance(file_path_prefix, ValueProvider)):
raise TypeError('file_path_prefix must be a string or ValueProvider;'
'got %r instead' % file_path_prefix)
if not (isinstance(file_name_suffix, basestring)
or isinstance(file_name_suffix, ValueProvider)):
raise TypeError('file_name_suffix must be a string or ValueProvider;'
'got %r instead' % file_name_suffix)

if not CompressionTypes.is_valid_compression_type(compression_type):
raise TypeError('compression_type must be CompressionType object but '
'was %s' % type(compression_type))

if shard_name_template is None:
shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
elif shard_name_template is '':
num_shards = 1
if isinstance(file_path_prefix, basestring):
file_path_prefix = StaticValueProvider(str, file_path_prefix)
if isinstance(file_name_suffix, basestring):
file_name_suffix = StaticValueProvider(str, file_name_suffix)
self.file_path_prefix = file_path_prefix
self.file_name_suffix = file_name_suffix
self.num_shards = num_shards
Expand Down Expand Up @@ -618,22 +627,31 @@ def close(self, file_handle):
if file_handle is not None:
file_handle.close()

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def initialize_write(self):
  """Creates the temporary directory for this write and returns its path.

  The directory name is the resolved file path prefix, followed by the
  resolved file name suffix and a ``-temp-<timestamp>`` marker.

  Returns:
    The path of the temporary directory that was created.
  """
  # Both attributes are ValueProviders; resolve them to plain strings before
  # concatenating, since the provider objects themselves cannot be joined.
  file_path_prefix = self.file_path_prefix.get()
  file_name_suffix = self.file_name_suffix.get()
  tmp_dir = file_path_prefix + file_name_suffix + time.strftime(
      '-temp-%Y-%m-%d_%H-%M-%S')
  ChannelFactory().mkdir(tmp_dir)
  return tmp_dir

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def open_writer(self, init_result, uid):
  """Opens a writer for one bundle inside the temporary directory.

  Args:
    init_result: the temporary directory path returned by
      ``initialize_write``.
    uid: identifier used as the leading part of the temporary file name.

  Returns:
    A ``FileSinkWriter`` bound to this sink and to the temporary file path.
  """
  # A proper suffix is needed for AUTO compression detection.
  # We also ensure there will be no collisions with uid and a
  # (possibly unsharded) file_path_prefix and a (possibly empty)
  # file_name_suffix.
  # Both attributes are ValueProviders, so they are resolved to strings
  # before being passed to os.path.basename / concatenated.
  file_path_prefix = self.file_path_prefix.get()
  file_name_suffix = self.file_name_suffix.get()
  suffix = (
      '.' + os.path.basename(file_path_prefix) + file_name_suffix)
  return FileSinkWriter(self, os.path.join(init_result, uid) + suffix)

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def finalize_write(self, init_result, writer_results):
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
writer_results = sorted(writer_results)
num_shards = len(writer_results)
min_threads = min(num_shards, FileSink._MAX_RENAME_THREADS)
Expand All @@ -642,8 +660,8 @@ def finalize_write(self, init_result, writer_results):
rename_ops = []
for shard_num, shard in enumerate(writer_results):
final_name = ''.join([
self.file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), self.file_name_suffix
file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), file_name_suffix
])
rename_ops.append((shard, final_name))

Expand Down
Loading