Skip to content

Commit

Permalink
support binary by NoneSplitter. (#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
toyama0919 authored and jesterhazy committed Sep 25, 2019
1 parent a999b16 commit d368524
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
27 changes: 23 additions & 4 deletions src/sagemaker/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,39 @@ def split(self, file):
class NoneSplitter(Splitter):
    """Does not split records, essentially reads the whole file."""

    # Byte values treated as text: the control characters BEL, BS, TAB, LF,
    # FF, CR and ESC, plus every value in 0x20-0xFF except DEL (0x7F).
    # A buffer containing any other byte is considered binary.
    _textchars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})

    def split(self, filename):
        """Split a file into records using a specific strategy.

        For this NoneSplitter there is no actual split happening and the file
        is returned as a whole. Content that looks like text is decoded to
        ``str``; binary content, or text that is not valid UTF-8, is returned
        unchanged as ``bytes``.

        Args:
            filename (str): path to the file to split

        Returns: generator for the individual records that were split from
            the file
        """
        # Read everything up front so the file handle is closed before the
        # generator hands control back to the caller.
        with open(filename, "rb") as f:
            buf = f.read()

        if not self._is_binary(buf):
            try:
                buf = buf.decode("utf-8")
            except UnicodeDecodeError:
                # The byte whitelist accepts 0x80-0xFF, so data in another
                # encoding (e.g. Latin-1) reaches this point; keep it as raw
                # bytes instead of failing.
                pass
        yield buf

    def _is_binary(self, buf):
        """Check whether `buf` contains binary data.

        Every byte listed in `_textchars` is deleted from `buf`; anything
        left over is a byte that is not considered text.

        Args:
            buf (bytes): data to inspect

        Returns:
            True if data is binary, otherwise False
        """
        return bool(buf.translate(None, self._textchars))


class LineSplitter(Splitter):
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/test_local_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,24 @@ def test_get_splitter_instance_with_invalid_types():


def test_none_splitter(tmpdir):
    """NoneSplitter yields a text file as one str record and a binary file as one bytes record."""
    splitter = sagemaker.local.data.NoneSplitter()

    # Text file: the whole content comes back as a single str record.
    test_file_path = tmpdir.join("none_test.txt")

    with test_file_path.open("w") as f:
        f.write("this\nis\na\ntest")

    data = list(splitter.split(str(test_file_path)))
    assert data == ["this\nis\na\ntest"]

    # Binary file: the payload contains NUL bytes, so it must come back
    # untouched as a single bytes record.
    test_bin_file_path = tmpdir.join("none_test.bin")

    with test_bin_file_path.open("wb") as f:
        f.write(b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00C")

    data = list(splitter.split(str(test_bin_file_path)))
    assert data == [b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00C"]


def test_line_splitter(tmpdir):
test_file_path = tmpdir.join("line_test.txt")
Expand Down

0 comments on commit d368524

Please sign in to comment.