Revert "Revert "Revert "Add ValueProvider class for FileBasedSource I…
Browse files Browse the repository at this point in the history
…/O Transforms"""

This reverts commit 28a0ea8.

Manually resolved Conflicts:
	sdks/python/apache_beam/io/filebasedsource.py
	sdks/python/apache_beam/runners/direct/direct_runner.py
aaltay committed Apr 12, 2017
1 parent 4854291 commit 0f5c363
Showing 14 changed files with 65 additions and 618 deletions.
32 changes: 14 additions & 18 deletions sdks/python/apache_beam/examples/wordcount.py
@@ -19,6 +19,7 @@

from __future__ import absolute_import

import argparse
import logging
import re

@@ -66,29 +67,24 @@ def process(self, element):

def run(argv=None):
"""Main entry point; defines and runs the wordcount pipeline."""
class WordcountOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
'--input',
dest='input',
default='gs://dataflow-samples/shakespeare/kinglear.txt',
help='Input file to process.')
parser.add_value_provider_argument(
'--output',
dest='output',
required=True,
help='Output file to write results to.')
pipeline_options = PipelineOptions(argv)
wordcount_options = pipeline_options.view_as(WordcountOptions)

parser = argparse.ArgumentParser()
parser.add_argument('--input',
dest='input',
default='gs://dataflow-samples/shakespeare/kinglear.txt',
help='Input file to process.')
parser.add_argument('--output',
dest='output',
required=True,
help='Output file to write results to.')
known_args, pipeline_args = parser.parse_known_args(argv)
# We use the save_main_session option because one or more DoFn's in this
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = True
p = beam.Pipeline(options=pipeline_options)

# Read the text file[pattern] into a PCollection.
lines = p | 'read' >> ReadFromText(wordcount_options.input)
lines = p | 'read' >> ReadFromText(known_args.input)

# Count the occurrences of each word.
counts = (lines
@@ -103,7 +99,7 @@ def _add_argparse_args(cls, parser):

# Write the output using a "Write" transform that has side effects.
# pylint: disable=expression-not-assigned
output | 'write' >> WriteToText(wordcount_options.output)
output | 'write' >> WriteToText(known_args.output)

# Actually run the pipeline (all operations above are deferred).
result = p.run()
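The restored wordcount.py goes back to plain argparse for its --input/--output options. A minimal, self-contained sketch of just that option-parsing step (the helper name parse_wordcount_args is hypothetical; flags and defaults follow the diff above):

import argparse

def parse_wordcount_args(argv=None):
  """Parse --input/--output the way the restored wordcount.py does."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--input',
                      dest='input',
                      default='gs://dataflow-samples/shakespeare/kinglear.txt',
                      help='Input file to process.')
  parser.add_argument('--output',
                      dest='output',
                      required=True,
                      help='Output file to write results to.')
  # Anything not recognized here (e.g. runner flags) is handed to PipelineOptions.
  return parser.parse_known_args(argv)

# Example: known_args, pipeline_args = parse_wordcount_args(['--output', '/tmp/counts'])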
6 changes: 0 additions & 6 deletions sdks/python/apache_beam/internal/gcp/json_value.py
@@ -25,8 +25,6 @@
extra_types = None
# pylint: enable=wrong-import-order, wrong-import-position

from apache_beam.utils.value_provider import ValueProvider


_MAXINT64 = (1 << 63) - 1
_MININT64 = - (1 << 63)
@@ -106,10 +104,6 @@ def to_json_value(obj, with_type=False):
raise TypeError('Can not encode {} as a 64-bit integer'.format(obj))
elif isinstance(obj, float):
return extra_types.JsonValue(double_value=obj)
elif isinstance(obj, ValueProvider):
if obj.is_accessible():
return to_json_value(obj.get())
return extra_types.JsonValue(is_null=True)
else:
raise TypeError('Cannot convert %s to a JSON value.' % repr(obj))

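With the ValueProvider branch removed, to_json_value still bounds-checks Python ints against the signed 64-bit range before wrapping them. A tiny sketch of that check, reusing the module's constants (the apitools extra_types JsonValue wrapping is omitted; fits_in_int64 is a hypothetical helper):

_MAXINT64 = (1 << 63) - 1
_MININT64 = -(1 << 63)

def fits_in_int64(n):
  """True if n can be encoded as a signed 64-bit integer."""
  return _MININT64 <= n <= _MAXINT64

assert fits_in_int64(2 ** 63 - 1)
assert not fits_in_int64(2 ** 63)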
54 changes: 14 additions & 40 deletions sdks/python/apache_beam/io/filebasedsource.py
@@ -32,9 +32,6 @@
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems_util import get_filesystem
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible

MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25

@@ -54,8 +51,7 @@ def __init__(self,
"""Initializes ``FileBasedSource``.
Args:
file_pattern: the file glob to read a string or a ValueProvider
(placeholder to inject a runtime value).
file_pattern: the file glob to read.
min_bundle_size: minimum size of bundles that should be generated when
performing initial splitting on this source.
compression_type: compression type to use
@@ -73,25 +69,17 @@ def __init__(self,
creation time.
Raises:
TypeError: when compression_type is not valid or if file_pattern is not a
string or a ValueProvider.
string.
ValueError: when compression and splittable files are specified.
IOError: when the file pattern specified yields an empty result.
"""
if not isinstance(file_pattern, basestring):
raise TypeError(
'%s: file_pattern must be a string; got %r instead' %
(self.__class__.__name__, file_pattern))

if (not (isinstance(file_pattern, basestring)
or isinstance(file_pattern, ValueProvider))):
raise TypeError('%s: file_pattern must be of type string'
' or ValueProvider; got %r instead'
% (self.__class__.__name__, file_pattern))

if isinstance(file_pattern, basestring):
file_pattern = StaticValueProvider(str, file_pattern)
self._pattern = file_pattern
if file_pattern.is_accessible():
self._file_system = get_filesystem(file_pattern.get())
else:
self._file_system = None

self._file_system = get_filesystem(file_pattern)
self._concat_source = None
self._min_bundle_size = min_bundle_size
if not CompressionTypes.is_valid_compression_type(compression_type):
@@ -104,24 +92,19 @@ def __init__(self,
else:
# We can't split compressed files efficiently so turn off splitting.
self._splittable = False
if validate and file_pattern.is_accessible():
if validate:
self._validate()

def display_data(self):
return {'file_pattern': DisplayDataItem(str(self._pattern),
return {'file_pattern': DisplayDataItem(self._pattern,
label="File Pattern"),
'compression': DisplayDataItem(str(self._compression_type),
label='Compression Type')}

@check_accessible(['_pattern'])
def _get_concat_source(self):
if self._concat_source is None:
pattern = self._pattern.get()

single_file_sources = []
if self._file_system is None:
self._file_system = get_filesystem(pattern)
match_result = self._file_system.match([pattern])[0]
match_result = self._file_system.match([self._pattern])[0]
files_metadata = match_result.metadata_list

# We create a reference for FileBasedSource that will be serialized along
Expand Down Expand Up @@ -160,19 +143,14 @@ def open_file(self, file_name):
file_name, 'application/octet-stream',
compression_type=self._compression_type)

@check_accessible(['_pattern'])
def _validate(self):
"""Validate if there are actual files in the specified glob pattern
"""
pattern = self._pattern.get()
if self._file_system is None:
self._file_system = get_filesystem(pattern)

# Limit the responses as we only want to check if something exists
match_result = self._file_system.match([pattern], limits=[1])[0]
match_result = self._file_system.match([self._pattern], limits=[1])[0]
if len(match_result.metadata_list) <= 0:
raise IOError(
'No files found based on the file pattern %s' % pattern)
'No files found based on the file pattern %s' % self._pattern)

def split(
self, desired_bundle_size=None, start_position=None, stop_position=None):
@@ -181,12 +159,8 @@ def split(
start_position=start_position,
stop_position=stop_position)

@check_accessible(['_pattern'])
def estimate_size(self):
pattern = self._pattern.get()
if self._file_system is None:
self._file_system = get_filesystem(pattern)
match_result = self._file_system.match([pattern])[0]
match_result = self._file_system.match([self._pattern])[0]
return sum([f.size_in_bytes for f in match_result.metadata_list])

def read(self, range_tracker):
@@ -211,7 +185,7 @@ def read_records(self, file_name, offset_range_tracker):
defined by a given ``RangeTracker``.
Returns:
an iterator that gives the records read from the given file.
a iterator that gives the records read from the given file.
"""
raise NotImplementedError

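FileBasedSource remains an abstract base class; concrete sources override read_records. Below is a minimal sketch of such a subclass, loosely modeled on the LineSource helper in filebasedsource_test.py. It assumes an uncompressed, seekable file and skips the boundary handling a production source needs when a record straddles two bundles:

from apache_beam.io.filebasedsource import FileBasedSource


class SimpleLineSource(FileBasedSource):
  """Yields one record per line of each matched file (illustrative only)."""

  def read_records(self, file_name, offset_range_tracker):
    f = self.open_file(file_name)
    try:
      # Start at the beginning of this bundle's offset range.
      f.seek(offset_range_tracker.start_position() or 0)
      while True:
        pos = f.tell()
        # Stop once the tracker refuses the next position (end of range).
        if not offset_range_tracker.try_claim(pos):
          break
        line = f.readline()
        if not line:
          break
        yield line.rstrip('\n')
    finally:
      f.close()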
24 changes: 0 additions & 24 deletions sdks/python/apache_beam/io/filebasedsource_test.py
@@ -43,8 +43,6 @@
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.util import assert_that
from apache_beam.transforms.util import equal_to
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import RuntimeValueProvider


class LineSource(FileBasedSource):
@@ -223,28 +221,6 @@ def setUp(self):
# environments with limited amount of resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

def test_string_or_value_provider_only(self):
str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
self.assertEqual(str_file_pattern,
FileBasedSource(str_file_pattern)._pattern.value)

static_vp_file_pattern = StaticValueProvider(value_type=str,
value=str_file_pattern)
self.assertEqual(static_vp_file_pattern,
FileBasedSource(static_vp_file_pattern)._pattern)

runtime_vp_file_pattern = RuntimeValueProvider(
option_name='arg',
value_type=str,
default_value=str_file_pattern,
options_id=1)
self.assertEqual(runtime_vp_file_pattern,
FileBasedSource(runtime_vp_file_pattern)._pattern)

invalid_file_pattern = 123
with self.assertRaises(TypeError):
FileBasedSource(invalid_file_pattern)

def test_validation_file_exists(self):
file_name, _ = write_data(10)
LineSource(file_name)
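With the ValueProvider cases gone, the remaining contract is simply that file_pattern must be a plain string. A small test sketch in the spirit of the removed invalid_file_pattern = 123 case (class and method names are hypothetical):

import unittest

from apache_beam.io.filebasedsource import FileBasedSource


class FilePatternTypeCheckTest(unittest.TestCase):

  def test_non_string_pattern_raises(self):
    # After the revert, any non-string pattern should fail fast with a
    # TypeError at construction time.
    with self.assertRaises(TypeError):
      FileBasedSource(123)


if __name__ == '__main__':
  unittest.main()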
56 changes: 14 additions & 42 deletions sdks/python/apache_beam/io/fileio.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sources and sinks."""

from __future__ import absolute_import
@@ -31,9 +30,6 @@
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems_util import get_filesystem
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.utils.value_provider import ValueProvider
from apache_beam.utils.value_provider import StaticValueProvider
from apache_beam.utils.value_provider import check_accessible

MAX_BATCH_OPERATION_SIZE = 100
DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
@@ -153,41 +149,33 @@ def __init__(self,
compression_type=CompressionTypes.AUTO):
"""
Raises:
TypeError: if file path parameters are not a string or ValueProvider,
or if compression_type is not member of CompressionTypes.
TypeError: if file path parameters are not a string or if compression_type
is not member of CompressionTypes.
ValueError: if shard_name_template is not of expected format.
"""
if not (isinstance(file_path_prefix, basestring)
or isinstance(file_path_prefix, ValueProvider)):
raise TypeError('file_path_prefix must be a string or ValueProvider;'
'got %r instead' % file_path_prefix)
if not (isinstance(file_name_suffix, basestring)
or isinstance(file_name_suffix, ValueProvider)):
raise TypeError('file_name_suffix must be a string or ValueProvider;'
'got %r instead' % file_name_suffix)
if not isinstance(file_path_prefix, basestring):
raise TypeError('file_path_prefix must be a string; got %r instead' %
file_path_prefix)
if not isinstance(file_name_suffix, basestring):
raise TypeError('file_name_suffix must be a string; got %r instead' %
file_name_suffix)

if not CompressionTypes.is_valid_compression_type(compression_type):
raise TypeError('compression_type must be CompressionType object but '
'was %s' % type(compression_type))

if shard_name_template is None:
shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
elif shard_name_template is '':
num_shards = 1
if isinstance(file_path_prefix, basestring):
file_path_prefix = StaticValueProvider(str, file_path_prefix)
if isinstance(file_name_suffix, basestring):
file_name_suffix = StaticValueProvider(str, file_name_suffix)
self.file_path_prefix = file_path_prefix
self.file_name_suffix = file_name_suffix
self.num_shards = num_shards
self.coder = coder
self.shard_name_format = self._template_to_format(shard_name_template)
self.compression_type = compression_type
self.mime_type = mime_type
if file_path_prefix.is_accessible():
self._file_system = get_filesystem(file_path_prefix.get())
else:
self._file_system = None
self._file_system = get_filesystem(file_path_prefix)

def display_data(self):
return {'shards':
@@ -201,15 +189,12 @@ def display_data(self):
self.file_name_suffix),
label='File Pattern')}

@check_accessible(['file_path_prefix'])
def open(self, temp_path):
"""Opens ``temp_path``, returning an opaque file handle object.
The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
if self._file_system is None:
self._file_system = get_filesystem(self.file_path_prefix.get())
return self._file_system.create(temp_path, self.mime_type,
self.compression_type)

@@ -236,33 +221,22 @@ def close(self, file_handle):
if file_handle is not None:
file_handle.close()

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def initialize_write(self):
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
tmp_dir = file_path_prefix + file_name_suffix + time.strftime(
tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime(
'-temp-%Y-%m-%d_%H-%M-%S')
if self._file_system is None:
self._file_system = get_filesystem(file_path_prefix)
self._file_system.mkdirs(tmp_dir)
return tmp_dir

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def open_writer(self, init_result, uid):
# A proper suffix is needed for AUTO compression detection.
# We also ensure there will be no collisions with uid and a
# (possibly unsharded) file_path_prefix and a (possibly empty)
# file_name_suffix.
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
suffix = (
'.' + os.path.basename(file_path_prefix) + file_name_suffix)
'.' + os.path.basename(self.file_path_prefix) + self.file_name_suffix)
return FileSinkWriter(self, os.path.join(init_result, uid) + suffix)

@check_accessible(['file_path_prefix', 'file_name_suffix'])
def finalize_write(self, init_result, writer_results):
file_path_prefix = self.file_path_prefix.get()
file_name_suffix = self.file_name_suffix.get()
writer_results = sorted(writer_results)
num_shards = len(writer_results)
min_threads = min(num_shards, FileSink._MAX_RENAME_THREADS)
@@ -272,8 +246,8 @@ def finalize_write(self, init_result, writer_results):
destination_files = []
for shard_num, shard in enumerate(writer_results):
final_name = ''.join([
file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), file_name_suffix
self.file_path_prefix, self.shard_name_format % dict(
shard_num=shard_num, num_shards=num_shards), self.file_name_suffix
])
source_files.append(shard)
destination_files.append(final_name)
@@ -296,8 +270,6 @@ def _rename_batch(batch):
"""_rename_batch executes batch rename operations."""
source_files, destination_files = batch
exceptions = []
if self._file_system is None:
self._file_system = get_filesystem(file_path_prefix)
try:
self._file_system.rename(source_files, destination_files)
return exceptions
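finalize_write composes each final shard name as file_path_prefix + shard_name_format % dict(...) + file_name_suffix. A short sketch of that composition, assuming the default '-SSSSS-of-NNNNN' template expands to a '%(shard_num)05d-of-%(num_shards)05d'-style format string (the actual conversion lives in FileSink._template_to_format):

def final_shard_name(file_path_prefix, file_name_suffix, shard_num, num_shards):
  """Compose a final shard name the way FileSink.finalize_write does."""
  shard_name_format = '-%(shard_num)05d-of-%(num_shards)05d'  # assumed expansion
  return ''.join([
      file_path_prefix,
      shard_name_format % dict(shard_num=shard_num, num_shards=num_shards),
      file_name_suffix,
  ])

# final_shard_name('gs://bucket/out', '.txt', 0, 3) -> 'gs://bucket/out-00000-of-00003.txt'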
