Commit

Merge pull request #106 from mattsb42-aws/log-cleanup
Log cleanup
mattsb42-aws committed Nov 15, 2018
2 parents 8f047fb + a5415cb commit fd07fbb
Showing 6 changed files with 62 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/aws_encryption_sdk/internal/formatting/deserialize.py
@@ -344,7 +344,7 @@ def deserialize_frame(stream, header, verifier=None):
(sequence_number,) = unpack_values(">I", stream, verifier)
final_frame = True
else:
_LOGGER.debug("Deserializing frame sequence number %s", int(sequence_number))
_LOGGER.debug("Deserializing frame sequence number %d", int(sequence_number))
frame_data["final_frame"] = final_frame
frame_data["sequence_number"] = sequence_number
(frame_iv,) = unpack_values(">{iv_len}s".format(iv_len=header.algorithm.iv_len), stream, verifier)
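Much of this diff swaps `%s` for `%d` in deferred-formatting log calls. The sketch below (standalone, not code from this repository) illustrates the difference: both placeholders delay interpolation until a record is actually emitted, but `%d` only accepts numbers, so an object whose `str()`/`repr()` might contain key material can no longer be silently rendered into a log line.

```python
import logging

logging.basicConfig(level=logging.DEBUG)
_LOGGER = logging.getLogger("demo")

# "%s" happily stringifies anything, including bytes that should stay private.
_LOGGER.debug("collecting %s bytes", b"\x02\x99 secret key material")

# "%d" raises TypeError during formatting for non-numeric arguments; the logging
# module catches it and reports the failure (traceback plus a "Call stack:" section)
# to stderr instead of writing the value into the log output.
_LOGGER.debug("collecting %d bytes", b"\x02\x99 secret key material")

# Intended usage: only plain integers reach the log line.
_LOGGER.debug("collecting %d bytes", 4096)
```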
@@ -150,7 +150,7 @@ def deserialize_encryption_context(serialized_encryption_context):
encryption_context = {}

dict_size, deserialized_size = read_short(source=serialized_encryption_context, offset=deserialized_size)
_LOGGER.debug("Found %s keys", dict_size)
_LOGGER.debug("Found %d keys", dict_size)
for _ in range(dict_size):
key_size, deserialized_size = read_short(source=serialized_encryption_context, offset=deserialized_size)
key, deserialized_size = read_string(
26 changes: 14 additions & 12 deletions src/aws_encryption_sdk/streaming_client.py
@@ -214,7 +214,7 @@ def read(self, b=-1):
if b is None or b < 0:
b = -1

_LOGGER.debug("Stream read called, requesting %s bytes", b)
_LOGGER.debug("Stream read called, requesting %d bytes", b)
output = io.BytesIO()

if not self._message_prepped:
@@ -234,7 +234,7 @@ def read(self, b=-1):
self.output_buffer = b""

self.bytes_read += output.tell()
_LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b)
_LOGGER.debug("Returning %d bytes of %d bytes requested", output.tell(), b)
return output.getvalue()

def tell(self):
@@ -538,7 +538,7 @@ def _read_bytes_to_framed_body(self, b):
:returns: Bytes read from source stream, encrypted, and serialized
:rtype: bytes
"""
_LOGGER.debug("collecting %s bytes", b)
_LOGGER.debug("collecting %d bytes", b)
_b = b

if b > 0:
@@ -565,10 +565,11 @@ def _read_bytes_to_framed_body(self, b):
# If finalizing on this pass, wait until final frame is written
or (finalize and not final_frame_written)
):
-is_final_frame = finalize and len(plaintext) < self.config.frame_length
-bytes_in_frame = min(len(plaintext), self.config.frame_length)
+current_plaintext_length = len(plaintext)
+is_final_frame = finalize and current_plaintext_length < self.config.frame_length
+bytes_in_frame = min(current_plaintext_length, self.config.frame_length)
_LOGGER.debug(
"Writing %s bytes into%s frame %s",
"Writing %d bytes into%s frame %d",
bytes_in_frame,
" final" if is_final_frame else "",
self.sequence_number,
@@ -719,7 +720,7 @@ def _read_header(self):
and header.frame_length > self.config.max_body_length
):
raise CustomMaximumValueExceeded(
"Frame Size in header found larger than custom value: {found} > {custom}".format(
"Frame Size in header found larger than custom value: {found:d} > {custom:d}".format(
found=header.frame_length, custom=self.config.max_body_length
)
)
@@ -758,7 +759,7 @@ def _prep_non_framed(self):

if self.config.max_body_length is not None and self.body_length > self.config.max_body_length:
raise CustomMaximumValueExceeded(
"Non-framed message content length found larger than custom value: {found} > {custom}".format(
"Non-framed message content length found larger than custom value: {found:d} > {custom:d}".format(
found=self.body_length, custom=self.config.max_body_length
)
)
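The `:d` format specifier in the exception messages above plays the same role for `str.format` that `%d` plays for the logger: only genuine integers are accepted, so nothing else can end up embedded in the exception text. A small illustration with made-up values:

```python
# Integer inputs format normally.
message = "Frame Size in header found larger than custom value: {found:d} > {custom:d}".format(
    found=4096, custom=1024
)
print(message)

# Anything that is not an integer is rejected outright.
try:
    "{found:d}".format(found="opaque object that might hold secrets")
except ValueError as error:
    print(error)  # Unknown format code 'd' for object of type 'str'
```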
@@ -792,7 +793,7 @@ def _read_bytes_from_non_framed_body(self, b):
_LOGGER.debug("starting non-framed body read")
# Always read the entire message for non-framed message bodies.
bytes_to_read = self.body_end - self.source_stream.tell()
_LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read)
_LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read)
ciphertext = self.source_stream.read(bytes_to_read)

if len(self.output_buffer) + len(ciphertext) < self.body_length:
@@ -821,13 +822,13 @@ def _read_bytes_from_framed_body(self, b):
"""
plaintext = b""
final_frame = False
_LOGGER.debug("collecting %s bytes", b)
_LOGGER.debug("collecting %d bytes", b)
while len(plaintext) < b and not final_frame:
_LOGGER.debug("Reading frame")
frame_data, final_frame = aws_encryption_sdk.internal.formatting.deserialize.deserialize_frame(
stream=self.source_stream, header=self._header, verifier=self.verifier
)
_LOGGER.debug("Read complete for frame %s", frame_data.sequence_number)
_LOGGER.debug("Read complete for frame %d", frame_data.sequence_number)
if frame_data.sequence_number != self.last_sequence_number + 1:
raise SerializationError("Malformed message: frames out of order")
self.last_sequence_number += 1
@@ -846,7 +847,8 @@ def _read_bytes_from_framed_body(self, b):
encrypted_data=frame_data,
associated_data=associated_data,
)
_LOGGER.debug("bytes collected: %s", len(plaintext))
plaintext_length = len(plaintext)
_LOGGER.debug("bytes collected: %d", plaintext_length)
if final_frame:
_LOGGER.debug("Reading footer")
self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
4 changes: 2 additions & 2 deletions src/aws_encryption_sdk/structures.py
@@ -75,7 +75,7 @@ class RawDataKey(object):
"""

key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo))
-data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes))
+data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes))


@attr.s(hash=True)
@@ -89,7 +89,7 @@ class DataKey(object):
"""

key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo))
-data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes))
+data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes))
encrypted_data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes))


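Adding `repr=False` to the `data_key` attribute keeps the raw key bytes out of the `__repr__` that `attrs` generates, which is what the new unit test at the bottom of this commit exercises. A minimal standalone sketch of the behaviour (a stand-in class, not the SDK's own structures):

```python
import attr


@attr.s(hash=True)
class FakeDataKey(object):
    """Stand-in loosely mirroring the SDK's data key structures."""

    key_provider = attr.ib(hash=True)
    data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes))


key = FakeDataKey(key_provider="example-provider", data_key=b"\x02\x99 raw key bytes")
print(repr(key))  # FakeDataKey(key_provider='example-provider') -- no key material shown
```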
24 changes: 21 additions & 3 deletions test/functional/test_f_aws_encryption_sdk_client.py
@@ -324,7 +324,9 @@ def test_encrypt_ciphertext_message(frame_length, algorithm, encryption_context)
(WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE),
),
)
-def test_encryption_cycle_raw_mkp(wrapping_algorithm, encryption_key_type, decryption_key_type):
+def test_encryption_cycle_raw_mkp(caplog, wrapping_algorithm, encryption_key_type, decryption_key_type):
+caplog.set_level(logging.DEBUG)

encrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, encryption_key_type)
decrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, decryption_key_type)
ciphertext, _ = aws_encryption_sdk.encrypt(
@@ -334,7 +336,10 @@ def test_encryption_cycle_raw_mkp(wrapping_algorithm, encryption_key_type, decry
frame_length=0,
)
plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=decrypting_key_provider)

assert plaintext == VALUES["plaintext_128"]
for member in encrypting_key_provider._members:
assert repr(member.config.wrapping_key._wrapping_key)[2:-1] not in caplog.text
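The new assertion relies on two things: `caplog.set_level(logging.DEBUG)` makes pytest capture the SDK's debug records, and `repr(raw_bytes)[2:-1]` strips the surrounding `b'...'` from a bytes repr so the needle matches the escaped payload however it might be quoted in a log line. A rough standalone illustration (fake key bytes, not the provider internals):

```python
wrapping_key = b"\x02\x99 raw wrapping key bytes"

# If the bytes were ever %s-formatted into a log message, the escaped repr form
# is what would appear there.
leaked_line = "wrapping key is %s" % wrapping_key

# repr(...) looks like b'\x02\x99 raw wrapping key bytes'; [2:-1] drops b' and the trailing '.
needle = repr(wrapping_key)[2:-1]

assert needle in leaked_line                  # a leak would be caught
assert needle not in "collecting 4096 bytes"  # clean log lines do not match
```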


@pytest.mark.skipif(
@@ -685,7 +690,11 @@ def _prep_plaintext_and_logs(log_catcher, plaintext_length):


def _look_in_logs(log_catcher, plaintext):
# Verify that no plaintext chunks are in the logs
logs = log_catcher.text
# look for all fake KMS data keys
for args in VALUES["data_keys"].values():
assert repr(args["plaintext"])[2:-1] not in logs
# look for every possible 32-byte chunk
start = 0
end = 32
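The loop body that follows `start`/`end` is collapsed in this view; conceptually it slides a 32-byte window over the plaintext and asserts that no window's escaped form appears in the captured logs. A rough reconstruction of that idea (an assumption about the collapsed code, not a verbatim copy):

```python
def assert_no_plaintext_chunks(logs, plaintext, window=32):
    """Slide a fixed-size window over plaintext and fail if any chunk shows up in logs."""
    start = 0
    end = window
    while end <= len(plaintext):
        chunk = plaintext[start:end]
        assert repr(chunk)[2:-1] not in logs
        start += 1
        end += 1
```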
@@ -698,25 +707,33 @@ def _look_in_logs(log_catcher, plaintext):
end += 1


def _error_check(capsys_instance):
# Verify that no errors were caught and ignored.
# The intent is to catch logging errors, but others will be caught as well.
stderr = capsys_instance.readouterr().err
assert "Call stack:" not in stderr


@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2))
@pytest.mark.parametrize(
"plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2)
)
-def test_plaintext_logs_oneshot(caplog, plaintext_length, frame_size):
+def test_plaintext_logs_oneshot(caplog, capsys, plaintext_length, frame_size):
plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length)

_ciphertext, _header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_size
)

_look_in_logs(caplog, plaintext)
_error_check(capsys)


@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2))
@pytest.mark.parametrize(
"plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2)
)
-def test_plaintext_logs_stream(caplog, plaintext_length, frame_size):
+def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):
plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length)

ciphertext = b""
@@ -727,3 +744,4 @@ def test_plaintext_logs_stream(caplog, plaintext_length, frame_size):
ciphertext += line

_look_in_logs(caplog, plaintext)
_error_check(capsys)
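For context on the streaming test, the collapsed lines build the ciphertext through the SDK's streaming context manager rather than the one-shot `encrypt` call. A sketch of that pattern with the 1.x module-level API (parameter values are placeholders; the exact body of the test is collapsed in this view):

```python
import io

import aws_encryption_sdk


def stream_encrypt(plaintext, key_provider, frame_size):
    """Collect framed ciphertext chunk by chunk, mirroring the shape of the test."""
    ciphertext = b""
    with aws_encryption_sdk.stream(
        mode="e", source=io.BytesIO(plaintext), key_provider=key_provider, frame_length=frame_size
    ) as encryptor:
        for chunk in encryptor:
            ciphertext += chunk
    return ciphertext
```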
23 changes: 23 additions & 0 deletions test/unit/test_structures.py
@@ -84,3 +84,26 @@ def test_master_key_info_convert(kwargs, attribute, expected_value):
test = MasterKeyInfo(**kwargs)

assert getattr(test, attribute) == expected_value


@pytest.mark.parametrize(
"cls, params",
(
(DataKey, ("key_provider", "data_key", "encrypted_data_key")),
(RawDataKey, ("key_provider", "data_key")),
(EncryptedDataKey, ("key_provider", "encrypted_data_key")),
),
)
def test_data_key_repr_str(cls, params):
data_key = b"plaintext data key ioasuwenvfiuawehnviuawh\x02\x99sd"
encrypted_data_key = b"encrypted data key josaidejoawuief\x02\x99sd"
base_params = dict(
key_provider=MasterKeyInfo(provider_id="asdf", key_info=b"fdsa"),
data_key=data_key,
encrypted_data_key=encrypted_data_key,
)
test = cls(**{key: base_params[key] for key in params})
data_key_check = repr(data_key)[2:-1]

assert data_key_check not in str(test)
assert data_key_check not in repr(test)
