[BEAM-1010] A few improvements to Apache Beam Python's FileIO.
This closes #1392
lukecwik committed Nov 21, 2016
2 parents c1440f7 + 6aa50c1 commit 8e88c7b
Showing 4 changed files with 84 additions and 10 deletions.
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/io/fileio.py
@@ -749,6 +749,12 @@ def flush(self):
def seekable(self):
return False

def __enter__(self):
return self

def __exit__(self, exception_type, exception_value, traceback):
self.close()


class FileSink(iobase.Sink):
"""A sink to a GCS or local files.
@@ -855,7 +861,13 @@ def initialize_write(self):
return tmp_dir

def open_writer(self, init_result, uid):
return FileSinkWriter(self, os.path.join(init_result, uid))
# A proper suffix is needed for AUTO compression detection.
# We also ensure there will be no collisions with uid and a
# (possibly unsharded) file_path_prefix and a (possibly empty)
# file_name_suffix.
suffix = (
'.' + os.path.basename(self.file_path_prefix) + self.file_name_suffix)
return FileSinkWriter(self, os.path.join(init_result, uid) + suffix)

def finalize_write(self, init_result, writer_results):
writer_results = sorted(writer_results)
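A quick illustration of the temp-file naming introduced in open_writer above: with CompressionTypes.AUTO the codec is chosen from the file extension, so the temporary shard written under the init_result directory has to end with the same suffix as the final output. The sketch below recomputes that suffix outside of Beam; the prefix, suffix, directory and uid values are made-up examples, and only the os.path expression from the diff is taken from the commit.

    import os

    # Hypothetical values a pipeline might supply.
    file_path_prefix = '/tmp/out/results'    # user-supplied output prefix
    file_name_suffix = '.gz'                 # requests gzip via AUTO detection
    init_result = '/tmp/out/beam-temp-1234'  # tmp_dir returned by initialize_write
    uid = '5678'                             # per-writer unique id

    # Same expression as in the diff: basename of the prefix plus the user suffix,
    # so the temp shard ends in '.results.gz' and AUTO detection picks gzip, while
    # the uid keeps shards from colliding with each other or with the final file.
    suffix = (
        '.' + os.path.basename(file_path_prefix) + file_name_suffix)
    temp_shard = os.path.join(init_result, uid) + suffix
    print(temp_shard)  # /tmp/out/beam-temp-1234/5678.results.gz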
48 changes: 44 additions & 4 deletions sdks/python/apache_beam/io/fileio_test.py
@@ -38,10 +38,7 @@
from apache_beam.transforms.display_test import DisplayDataItemMatcher

# TODO: Add tests for file patterns (ie not just individual files) for both
# uncompressed

# TODO: Update code to not use NamedTemporaryFile (or to use it in a way that
# doesn't violate its assumptions).
# compressed and uncompressed files.


class TestTextFileSource(unittest.TestCase):
@@ -721,6 +718,49 @@ def test_write_text_bzip2_file_empty(self):
with bz2.BZ2File(self.path, 'r') as f:
self.assertEqual(f.read().splitlines(), [])

def test_write_dataflow(self):
pipeline = beam.Pipeline('DirectPipelineRunner')
pcoll = pipeline | beam.core.Create('Create', self.lines)
pcoll | 'Write' >> beam.Write(fileio.NativeTextFileSink(self.path)) # pylint: disable=expression-not-assigned
pipeline.run()

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'r') as f:
read_result.extend(f.read().splitlines())

self.assertEqual(read_result, self.lines)

def test_write_dataflow_auto_compression(self):
pipeline = beam.Pipeline('DirectPipelineRunner')
pcoll = pipeline | beam.core.Create('Create', self.lines)
pcoll | 'Write' >> beam.Write( # pylint: disable=expression-not-assigned
fileio.NativeTextFileSink(
self.path, file_name_suffix='.gz'))
pipeline.run()

read_result = []
for file_name in glob.glob(self.path + '*'):
with gzip.GzipFile(file_name, 'r') as f:
read_result.extend(f.read().splitlines())

self.assertEqual(read_result, self.lines)

def test_write_dataflow_auto_compression_unsharded(self):
pipeline = beam.Pipeline('DirectPipelineRunner')
pcoll = pipeline | beam.core.Create('Create', self.lines)
pcoll | 'Write' >> beam.Write( # pylint: disable=expression-not-assigned
fileio.NativeTextFileSink(
self.path + '.gz', shard_name_template=''))
pipeline.run()

read_result = []
for file_name in glob.glob(self.path + '*'):
with gzip.GzipFile(file_name, 'r') as f:
read_result.extend(f.read().splitlines())

self.assertEqual(read_result, self.lines)


class MyFileSink(fileio.FileSink):

6 changes: 1 addition & 5 deletions sdks/python/apache_beam/io/textio.py
@@ -85,11 +85,9 @@ def __init__(self, file_pattern, min_bundle_size,

def read_records(self, file_name, range_tracker):
start_offset = range_tracker.start_position()

read_buffer = _TextSource.ReadBuffer('', 0)
file_to_read = self.open_file(file_name)

try:
with self.open_file(file_name) as file_to_read:
if start_offset > 0:
# Seeking to one position before the start index and ignoring the
# current line. If start_position is at beginning of the line, that line
@@ -116,8 +114,6 @@ def read_records(self, file_name, range_tracker):
if num_bytes_to_next_record < 0:
break
next_record_start_position += num_bytes_to_next_record
finally:
file_to_read.close()

def _find_separator_bounds(self, file_to_read, read_buffer):
# Determines the start and end positions within 'read_buffer.data' of the
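The read_records rewrite above only works because the file object returned by open_file is now a context manager: the __enter__ and __exit__ methods added to fileio.py earlier in this commit return the object on entry and close it on exit. A toy stand-in class (not a real Beam type) showing the same try/finally-to-with refactor:

    class ManagedFile(object):
        """Minimal stand-in for a file-like wrapper; not part of Beam."""

        def __init__(self, name):
            self.name = name
            self.closed = False

        def read(self):
            return 'data from %s' % self.name

        def close(self):
            self.closed = True

        # The two methods added in fileio.py: return self on entry, close on
        # exit, so callers can drop their explicit try/finally blocks.
        def __enter__(self):
            return self

        def __exit__(self, exception_type, exception_value, traceback):
            self.close()


    # Before this commit: manual cleanup.
    f = ManagedFile('old-style')
    try:
        f.read()
    finally:
        f.close()

    # After: the with statement closes the file even if read() raises.
    with ManagedFile('new-style') as f:
        f.read()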
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/io/textio_test.py
Expand Up @@ -491,6 +491,32 @@ def test_write_dataflow(self):

self.assertEqual(read_result, self.lines)

def test_write_dataflow_auto_compression(self):
pipeline = beam.Pipeline('DirectPipelineRunner')
pcoll = pipeline | beam.core.Create('Create', self.lines)
pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz') # pylint: disable=expression-not-assigned
pipeline.run()

read_result = []
for file_name in glob.glob(self.path + '*'):
with gzip.GzipFile(file_name, 'r') as f:
read_result.extend(f.read().splitlines())

self.assertEqual(read_result, self.lines)

def test_write_dataflow_auto_compression_unsharded(self):
pipeline = beam.Pipeline('DirectPipelineRunner')
pcoll = pipeline | beam.core.Create('Create', self.lines)
pcoll | 'Write' >> WriteToText(self.path + '.gz', shard_name_template='') # pylint: disable=expression-not-assigned
pipeline.run()

read_result = []
for file_name in glob.glob(self.path + '*'):
with gzip.GzipFile(file_name, 'r') as f:
read_result.extend(f.read().splitlines())

self.assertEqual(read_result, self.lines)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
