Skip to content

Commit

Permalink
Merge pull request #98 from mattsb42-aws/read
Browse files Browse the repository at this point in the history
Properly handle negative values to _EncryptionStream.read()
  • Loading branch information
mattsb42-aws committed Nov 9, 2018
2 parents 66d06e4 + 870707c commit 7013f40
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 39 deletions.
4 changes: 4 additions & 0 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def read(self, b=None):
:returns: Processed (encrypted or decrypted) bytes from source stream
:rtype: bytes
"""
# Any negative value for b is interpreted as a full read
if b is not None and b < 0:
b = None

_LOGGER.debug("Stream read called, requesting %s bytes", b)
output = io.BytesIO()
if not self._message_prepped:
Expand Down
107 changes: 68 additions & 39 deletions test/unit/test_streaming_client_encryption_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit test suite for aws_encryption_sdk.streaming_client._EncryptionStream"""
import copy
import io
import unittest

import attr
import pytest
import six
from mock import MagicMock, PropertyMock, call, patch, sentinel

import aws_encryption_sdk.exceptions
Expand Down Expand Up @@ -45,20 +44,22 @@ def _read_bytes(self, b):
return self.config.mock_read_bytes


class TestEncryptionStream(unittest.TestCase):
def setUp(self):
self.mock_source_stream = MagicMock()
self.mock_source_stream.__class__ = io.IOBase
self.mock_source_stream.tell.side_effect = (10, 500)
class TestEncryptionStream(object):
def _mock_key_provider(self):
mock_key_provider = MagicMock()
mock_key_provider.__class__ = MasterKeyProvider
return mock_key_provider

self.mock_key_provider = MagicMock()
self.mock_key_provider.__class__ = MasterKeyProvider
def _mock_source_stream(self):
mock_source_stream = MagicMock()
mock_source_stream.__class__ = io.IOBase
mock_source_stream.tell.side_effect = (10, 500)
return mock_source_stream

self.mock_line_length = MagicMock()
self.mock_line_length.__class__ = int

self.mock_source_length = MagicMock()
self.mock_source_length.__class__ = int
@pytest.fixture(autouse=True)
def apply_fixtures(self):
self.mock_key_provider = self._mock_key_provider()
self.mock_source_stream = self._mock_source_stream()

def test_read_bytes_enforcement(self):
class TestStream(_EncryptionStream):
Expand All @@ -67,19 +68,23 @@ class TestStream(_EncryptionStream):
def _prep_message(self):
pass

with six.assertRaisesRegex(self, TypeError, "Can't instantiate abstract class TestStream"):
with pytest.raises(TypeError) as excinfo:
TestStream()

excinfo.match("Can't instantiate abstract class TestStream")

def test_prep_message_enforcement(self):
class TestStream(_EncryptionStream):
_config_class = MockClientConfig

def _read_bytes(self):
pass

with six.assertRaisesRegex(self, TypeError, "Can't instantiate abstract class TestStream"):
with pytest.raises(TypeError) as excinfo:
TestStream()

excinfo.match("Can't instantiate abstract class TestStream")

def test_config_class_enforcement(self):
class TestStream(_EncryptionStream):
def _read_bytes(self):
Expand All @@ -88,29 +93,32 @@ def _read_bytes(self):
def _prep_message(self):
pass

with six.assertRaisesRegex(self, TypeError, "Can't instantiate abstract class TestStream"):
with pytest.raises(TypeError) as excinfo:
TestStream()

excinfo.match("Can't instantiate abstract class TestStream")

def test_new_with_params(self):
mock_int_sentinel = MagicMock(__class__=int)
mock_stream = MockEncryptionStream(
source=self.mock_source_stream,
key_provider=self.mock_key_provider,
mock_read_bytes=sentinel.read_bytes,
line_length=self.mock_line_length,
source_length=self.mock_source_length,
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)
assert mock_stream.config == MockClientConfig(
source=self.mock_source_stream,
key_provider=self.mock_key_provider,
mock_read_bytes=sentinel.read_bytes,
line_length=self.mock_line_length,
source_length=self.mock_source_length,
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)
assert mock_stream.bytes_read == 0
assert mock_stream.output_buffer == b""
assert not mock_stream._message_prepped
assert mock_stream.source_stream is self.mock_source_stream
assert mock_stream._stream_length is self.mock_source_length
assert mock_stream._stream_length is mock_int_sentinel
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE

def test_new_with_config(self):
Expand Down Expand Up @@ -154,7 +162,8 @@ class CustomUnknownError(Exception):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
with self.assertRaises(CustomUnknownError):

with pytest.raises(CustomUnknownError):
mock_stream.__exit__(None, None, None)

def test_stream_length(self):
Expand All @@ -173,9 +182,12 @@ def test_stream_length_unsupported(self):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
with six.assertRaisesRegex(self, aws_encryption_sdk.exceptions.NotSupportedError, "Unexpected exception!"):

with pytest.raises(aws_encryption_sdk.exceptions.NotSupportedError) as excinfo:
mock_stream.stream_length # pylint: disable=pointless-statement

excinfo.match("Unexpected exception!")

def test_header_property(self):
mock_prep_message = MagicMock()
mock_stream = MockEncryptionStream(
Expand Down Expand Up @@ -205,31 +217,37 @@ def test_read_closed(self):
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
mock_stream.close()
with six.assertRaisesRegex(self, ValueError, "I/O operation on closed file"):

with pytest.raises(ValueError) as excinfo:
mock_stream.read()

def test_read_b(self):
excinfo.match("I/O operation on closed file")

@pytest.mark.parametrize("bytes_to_read", range(1, 11))
def test_read_b(self, bytes_to_read):
mock_stream = MockEncryptionStream(
source=io.BytesIO(VALUES["data_128"]),
key_provider=self.mock_key_provider,
mock_read_bytes=sentinel.read_bytes,
)
data = b"1234567890"
mock_stream._read_bytes = MagicMock()
mock_stream.output_buffer = b"1234567890"
test = mock_stream.read(5)
mock_stream._read_bytes.assert_called_once_with(5)
assert test == b"12345"
assert mock_stream.output_buffer == b"67890"

def test_read_all(self):
mock_stream.output_buffer = copy.copy(data)
test = mock_stream.read(bytes_to_read)
mock_stream._read_bytes.assert_called_once_with(bytes_to_read)
assert test == data[:bytes_to_read]
assert mock_stream.output_buffer == data[bytes_to_read:]

@pytest.mark.parametrize("bytes_to_read", (None, -1, -99))
def test_read_all(self, bytes_to_read):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
mock_stream._stream_length = 5
mock_stream.output_buffer = b"1234567890"
mock_stream.source_stream = MagicMock()
type(mock_stream.source_stream).closed = PropertyMock(side_effect=(False, False, True))
test = mock_stream.read()
test = mock_stream.read(bytes_to_read)
assert test == b"1234567890"

def test_read_all_empty_source(self):
Expand Down Expand Up @@ -262,23 +280,32 @@ def test_writelines(self):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
with six.assertRaisesRegex(self, NotImplementedError, "writelines is not available for this object"):

with pytest.raises(NotImplementedError) as excinfo:
mock_stream.writelines(None)

excinfo.match("writelines is not available for this object")

def test_write(self):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
with six.assertRaisesRegex(self, NotImplementedError, "write is not available for this object"):

with pytest.raises(NotImplementedError) as excinfo:
mock_stream.write(None)

excinfo.match("write is not available for this object")

def test_seek(self):
mock_stream = MockEncryptionStream(
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
with six.assertRaisesRegex(self, NotImplementedError, "seek is not available for this object"):

with pytest.raises(NotImplementedError) as excinfo:
mock_stream.seek(None)

excinfo.match("seek is not available for this object")

def test_readline(self):
test_line = "TEST_LINE_AAAA"
test_line_length = len(test_line)
Expand Down Expand Up @@ -319,7 +346,8 @@ def test_next_stream_closed(self):
source=self.mock_source_stream, key_provider=self.mock_key_provider, mock_read_bytes=sentinel.read_bytes
)
mock_stream.close()
with self.assertRaises(StopIteration):

with pytest.raises(StopIteration):
mock_stream.next()

def test_next_source_stream_closed_and_buffer_empty(self):
Expand All @@ -328,7 +356,8 @@ def test_next_source_stream_closed_and_buffer_empty(self):
)
self.mock_source_stream.closed = True
mock_stream.output_buffer = b""
with self.assertRaises(StopIteration):

with pytest.raises(StopIteration):
mock_stream.next()

@patch("aws_encryption_sdk.streaming_client._EncryptionStream.closed", new_callable=PropertyMock)
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ commands =
accept: {[testenv:base-command]commands} test/ -m accept
examples: {[testenv:base-command]commands} examples/test/ -m examples
all: {[testenv:base-command]commands} test/ examples/test/
manual: {[testenv:base-command]commands}

# Verify that local tests work without environment variables present
[testenv:nocmk]
Expand Down

0 comments on commit 7013f40

Please sign in to comment.