Skip to content

Commit

Permalink
Merge pull request #102 from mattsb42-aws/dev-24
Browse files Browse the repository at this point in the history
Fix handling of partial reads
  • Loading branch information
mattsb42-aws committed Nov 13, 2018
2 parents 7013f40 + 74355b0 commit 1875c91
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 42 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Changelog
Minor
-----

* Add support to remove clients from :ref:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
* Add support to remove clients from :class:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
`#86 <https://github.com/aws/aws-encryption-sdk-python/pull/86>`_
* Add support for SHA384 and SHA512 for use with RSA OAEP wrapping algorithms.
`#56 <https://github.com/aws/aws-encryption-sdk-python/issues/56>`_
* Fix ``streaming_client`` classes to properly interpret short reads in source streams.
`#24 <https://github.com/aws/aws-encryption-sdk-python/issues/24>`_

1.3.7 -- 2018-09-20
===================
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ branch = True
show_missing = True

[tool:pytest]
log_level = DEBUG
markers =
local: superset of unit and functional (does not require network access)
unit: mark test as a unit test (does not require network access)
Expand Down
10 changes: 7 additions & 3 deletions src/aws_encryption_sdk/internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from aws_encryption_sdk.internal.str_ops import to_bytes
from aws_encryption_sdk.structures import EncryptedDataKey

from .streams import InsistentReaderBytesIO

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -132,12 +134,14 @@ def prep_stream_data(data):
:param data: Input data
:returns: Prepared stream
:rtype: io.BytesIO
:rtype: InsistentReaderBytesIO
"""
if isinstance(data, (six.string_types, six.binary_type)):
return io.BytesIO(to_bytes(data))
stream = io.BytesIO(to_bytes(data))
else:
stream = data

return data
return InsistentReaderBytesIO(stream)


def source_data_key_length_check(source_data_key, algorithm):
Expand Down
41 changes: 41 additions & 0 deletions src/aws_encryption_sdk/internal/utils/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Helper stream utility objects for AWS Encryption SDK."""
import io

from wrapt import ObjectProxy

from aws_encryption_sdk.exceptions import ActionNotAllowedError
from aws_encryption_sdk.internal.str_ops import to_bytes


class ROStream(ObjectProxy):
Expand Down Expand Up @@ -56,3 +59,41 @@ def read(self, b=None):
data = self.__wrapped__.read(b)
self.__tee.write(data)
return data


class InsistentReaderBytesIO(ObjectProxy):
    """Wrapper around a readable stream that insists on reading exactly the requested
    number of bytes. It will keep trying to read bytes from the wrapped stream until
    either the requested number of bytes are available or the wrapped stream has
    nothing more to return.

    :param wrapped: File-like object
    """

    def read(self, b=-1):
        """Keep reading from source stream until either the source stream is done
        or the requested number of bytes have been obtained.

        :param int b: number of bytes to read; ``None`` or any negative value means
            "read until the wrapped stream has nothing more to return"
        :return: All bytes read from wrapped stream
        :rtype: bytes
        """
        # None is accepted as a full read, matching the streaming client's ``read``.
        read_all = b is None or b < 0
        remaining_bytes = -1 if read_all else b
        data = io.BytesIO()
        while True:
            try:
                chunk = to_bytes(self.__wrapped__.read(remaining_bytes))
            except ValueError:
                # A ValueError from a closed wrapped stream just means there is
                # nothing more to read; any other ValueError is a real error.
                if self.__wrapped__.closed:
                    break
                raise

            if not chunk:
                # An empty read means the wrapped stream is exhausted.
                break

            data.write(chunk)

            if read_all:
                # A full read is only complete once the wrapped stream returns
                # an empty chunk, so a short read must be followed by a retry.
                continue

            remaining_bytes -= len(chunk)
            if remaining_bytes <= 0:
                break
        return data.getvalue()
47 changes: 33 additions & 14 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,24 +202,28 @@ def readable(self):
# Open streams are currently always readable.
return not self.closed

def read(self, b=None):
def read(self, b=-1):
"""Returns either the requested number of bytes or the entire stream.
:param int b: Number of bytes to read
:returns: Processed (encrypted or decrypted) bytes from source stream
:rtype: bytes
"""
# Any negative value for b is interpreted as a full read
if b is not None and b < 0:
b = None
# None is also accepted for legacy compatibility
if b is None or b < 0:
b = -1

_LOGGER.debug("Stream read called, requesting %s bytes", b)
output = io.BytesIO()

if not self._message_prepped:
self._prep_message()

if self.closed:
raise ValueError("I/O operation on closed file")
if b:

if b >= 0:
self._read_bytes(b)
output.write(self.output_buffer[:b])
self.output_buffer = self.output_buffer[b:]
Expand All @@ -228,6 +232,7 @@ def read(self, b=None):
self._read_bytes(LINE_LENGTH)
output.write(self.output_buffer)
self.output_buffer = b""

self.bytes_read += output.tell()
_LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b)
return output.getvalue()
Expand Down Expand Up @@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b):
_LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b)
self.source_stream.close()
closing = self.encryptor.finalize()

if self.signer is not None:
self.signer.update(closing)

closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close(
tag=self.encryptor.tag, signer=self.signer
)

if self.signer is not None:
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer)
return ciphertext + closing

return ciphertext

def _read_bytes_to_framed_body(self, b):
Expand All @@ -530,14 +539,22 @@ def _read_bytes_to_framed_body(self, b):
"""
_LOGGER.debug("collecting %s bytes", b)
_b = b
b = int(math.ceil(b / float(self.config.frame_length)) * self.config.frame_length)
_LOGGER.debug("%s bytes requested; reading %s bytes after normalizing to frame length", _b, b)

if b > 0:
_frames_to_read = math.ceil(b / float(self.config.frame_length))
b = int(_frames_to_read * self.config.frame_length)
_LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b)

plaintext = self.source_stream.read(b)
_LOGGER.debug("%s bytes read from source", len(plaintext))
plaintext_length = len(plaintext)
_LOGGER.debug("%d bytes read from source", plaintext_length)

finalize = False
if len(plaintext) < b:

if b < 0 or plaintext_length < b:
_LOGGER.debug("Final plaintext read from source")
finalize = True

output = b""
final_frame_written = False

Expand Down Expand Up @@ -583,8 +600,8 @@ def _read_bytes(self, b):
:param int b: Number of bytes to read
:raises NotSupportedError: if content type is not supported
"""
_LOGGER.debug("%s bytes requested from stream with content type: %s", b, self.content_type)
if b <= len(self.output_buffer) or self.source_stream.closed:
_LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type)
if 0 <= b <= len(self.output_buffer) or self.source_stream.closed:
_LOGGER.debug("No need to read from source stream or source stream closed")
return

Expand Down Expand Up @@ -776,10 +793,13 @@ def _read_bytes_from_non_framed_body(self, b):
bytes_to_read = self.body_end - self.source_stream.tell()
_LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read)
ciphertext = self.source_stream.read(bytes_to_read)

if len(self.output_buffer) + len(ciphertext) < self.body_length:
raise SerializationError("Total message body contents less than specified in body description")

if self.verifier is not None:
self.verifier.update(ciphertext)

plaintext = self.decryptor.update(ciphertext)
plaintext += self.decryptor.finalize()
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
Expand Down Expand Up @@ -844,10 +864,9 @@ def _read_bytes(self, b):
_LOGGER.debug("Source stream closed")
return

if b <= len(self.output_buffer):
_LOGGER.debug(
"%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer)
)
buffer_length = len(self.output_buffer)
if 0 <= b <= buffer_length:
_LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length)
return

if self._header.content_type == ContentType.FRAMED_DATA:
Expand Down
71 changes: 71 additions & 0 deletions test/functional/test_f_aws_encryption_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,74 @@ def test_stream_decryptor_readable():
assert handler.readable()
handler.read()
assert not handler.readable()


def exact_length_plaintext(length):
    """Build a plaintext of exactly ``length`` bytes by repeating the 128-byte test value.

    :param int length: Desired plaintext length in bytes
    :returns: ``VALUES["plaintext_128"]`` repeated and truncated to ``length`` bytes
    :rtype: bytes
    """
    buffer = bytearray()
    # Keep appending whole copies until there is enough to truncate from.
    while len(buffer) < length:
        buffer.extend(VALUES["plaintext_128"])
    return bytes(buffer[:length])


class SometimesIncompleteReaderIO(io.BytesIO):
    """In-memory byte stream that deliberately returns short reads on every second ``read`` call."""

    def __init__(self, *args, **kwargs):
        # Number of ``read`` calls made so far; used to pick which calls are shortened.
        self.__read_counter = 0
        super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)

    def read(self, size=-1):
        """Every other read request, return fewer than the requested number of bytes if more than one byte requested."""
        self.__read_counter += 1
        is_even_call = self.__read_counter % 2 == 0
        # Only requests for more than one byte can be shortened; halve those on even calls.
        effective_size = size // 2 if (size > 1 and is_even_call) else size
        return super(SometimesIncompleteReaderIO, self).read(effective_size)


def _read_all_in_chunks(stream, chunk_size, max_cycles, operation):
    """Drain ``stream`` by repeatedly requesting ``chunk_size`` bytes until an empty read.

    :param stream: Readable stream to drain
    :param int chunk_size: Number of bytes to request on each read
    :param int max_cycles: Number of read cycles after which the loop is considered stuck
    :param str operation: Verb describing the operation, used in the error message
    :returns: All bytes read from ``stream``
    :rtype: bytes
    :raises AWSEncryptionSDKClientError: if more than ``max_cycles`` non-empty reads occur
    """
    chunks = []
    cycle_count = 0
    while True:
        cycle_count += 1
        chunk = stream.read(chunk_size)
        if not chunk:
            break
        chunks.append(chunk)
        # Guard against a stream that never reports completion.
        if cycle_count > max_cycles:
            raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
                "Unexpected error {} message: infinite loop detected.".format(operation)
            )
    return b"".join(chunks)


@pytest.mark.parametrize(
    "frame_length",
    (
        0,  # 0: unframed
        128,  # 128: framed with exact final frame size match
        256,  # 256: framed with inexact final frame size match
    ),
)
def test_incomplete_read_stream_cycle(frame_length):
    """Encrypt then decrypt through source streams that return short reads and
    verify that the original plaintext is recovered.
    """
    chunk_size = 21  # Will never be an exact match for the frame size
    max_cycles = len(VALUES["plaintext_128"])
    key_provider = fake_kms_key_provider()

    plaintext = exact_length_plaintext(384)
    with aws_encryption_sdk.stream(
        mode="encrypt",
        source=SometimesIncompleteReaderIO(plaintext),
        key_provider=key_provider,
        frame_length=frame_length,
    ) as encryptor:
        ciphertext = _read_all_in_chunks(encryptor, chunk_size, max_cycles, "encrypting")

    with aws_encryption_sdk.stream(
        mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider
    ) as decryptor:
        decrypted = _read_all_in_chunks(decryptor, chunk_size, max_cycles, "decrypting")

    assert ciphertext != decrypted == plaintext
19 changes: 11 additions & 8 deletions test/unit/test_streaming_client_encryption_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import aws_encryption_sdk.exceptions
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream

Expand Down Expand Up @@ -107,17 +108,19 @@ def test_new_with_params(self):
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)
assert mock_stream.config == MockClientConfig(
source=self.mock_source_stream,
key_provider=self.mock_key_provider,
mock_read_bytes=sentinel.read_bytes,
line_length=io.DEFAULT_BUFFER_SIZE,
source_length=mock_int_sentinel,
)

assert mock_stream.config.source == self.mock_source_stream
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
assert mock_stream.config.key_provider is self.mock_key_provider
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
assert mock_stream.config.source_length is mock_int_sentinel

assert mock_stream.bytes_read == 0
assert mock_stream.output_buffer == b""
assert not mock_stream._message_prepped
assert mock_stream.source_stream is self.mock_source_stream
assert mock_stream.source_stream == self.mock_source_stream
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
assert mock_stream._stream_length is mock_int_sentinel
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_streaming_client_stream_decryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_prep_non_framed(self):
test_decryptor._prep_non_framed()

self.mock_deserialize_non_framed_values.assert_called_once_with(
stream=self.mock_input_stream, header=self.mock_header, verifier=sentinel.verifier
stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier
)
assert test_decryptor.body_length == len(VALUES["data_128"])
self.mock_get_aad_content_string.assert_called_once_with(
Expand Down

0 comments on commit 1875c91

Please sign in to comment.