Skip to content

Commit

Permalink
Adds support for checksums in streamed request trailers (#962)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrycain committed Nov 29, 2022
1 parent 551343c commit 8d9d71a
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 34 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
Changes
-------
2.4.1 (2022-11-28)
^^^^^^^^^^^^^^^^^^
* Adds support for checksums in streamed request trailers (thanks @terrycain #962)

2.4.0 (2022-08-25)
^^^^^^^^^^^^^^^^^^
* bump botocore to 1.27.59
Expand Down
2 changes: 1 addition & 1 deletion aiobotocore/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.4.0'
__version__ = '2.4.1'
2 changes: 1 addition & 1 deletion aiobotocore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
PaginatorDocstring,
S3ArnParamHandler,
S3EndpointSetter,
apply_request_checksum,
logger,
resolve_checksum_context,
)
Expand All @@ -20,6 +19,7 @@
from . import waiter
from .args import AioClientArgsCreator
from .discovery import AioEndpointDiscoveryHandler, AioEndpointDiscoveryManager
from .httpchecksum import apply_request_checksum
from .paginate import AioPaginator
from .retries import adaptive, standard
from .utils import AioS3RegionRedirector
Expand Down
91 changes: 91 additions & 0 deletions aiobotocore/httpchecksum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,48 @@
import io

from botocore.httpchecksum import (
_CHECKSUM_CLS,
AwsChunkedWrapper,
FlexibleChecksumError,
_apply_request_header_checksum,
_handle_streaming_response,
base64,
conditionally_calculate_md5,
determine_content_length,
logger,
)

from aiobotocore._helpers import resolve_awaitable


class AioAwsChunkedWrapper(AwsChunkedWrapper):
    """Async-iterable variant of botocore's ``AwsChunkedWrapper``.

    Each step of the async iteration reads up to ``self._chunk_size`` bytes
    from ``self._raw`` and frames them as ``<hex length>\\r\\n<data>\\r\\n``.
    Once the underlying stream is exhausted a final zero-length chunk is
    produced; when a checksum object is configured, that final chunk carries
    the ``<checksum name>:<base64 digest>`` trailer.
    """

    async def _make_chunk(self):
        """Read the next piece of the raw body and return it as one chunk.

        Side effects: sets ``self._complete`` to ``True`` when the read
        returned nothing, and feeds the data read into ``self._checksum``
        when one is configured.
        """
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size

        # The raw body's read() may be sync or async; resolve_awaitable
        # presumably awaits only when needed (see aiobotocore._helpers).
        data = await resolve_awaitable(self._raw.read(self._chunk_size))
        size_hex = b"%x" % len(data)
        self._complete = not data

        if self._checksum:
            self._checksum.update(data)
            if self._complete:
                # End of stream: emit the terminating chunk with the trailer.
                trailer_name = self._checksum_name.encode("ascii")
                trailer_value = self._checksum.b64digest().encode("ascii")
                return b"0\r\n%s:%s\r\n\r\n" % (trailer_name, trailer_value)

        return b"%s\r\n%s\r\n" % (size_hex, data)

    def __aiter__(self):
        """Return ``self``; the wrapper is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next chunk, or stop once the final chunk was produced."""
        if self._complete:
            raise StopAsyncIteration()
        return await self._make_chunk()


async def handle_checksum_body(
http_response, response, context, operation_model
Expand Down Expand Up @@ -67,3 +104,57 @@ async def _handle_bytes_response(http_response, response, algorithm):
)
raise FlexibleChecksumError(error_msg=error_msg)
return body


def apply_request_checksum(request):
    """Apply the checksum selected in ``request["context"]["checksum"]``.

    Does nothing when no ``request_algorithm`` is present. Otherwise the
    request is handed to the helper matching the algorithm: the MD5 helper
    for ``"conditional-md5"``, or the header / trailer helper according to
    ``algorithm["in"]``.

    :param request: request dict; the helpers it is passed to modify it
        in place.
    :raises FlexibleChecksumError: if ``algorithm["in"]`` is neither
        ``"header"`` nor ``"trailer"``.
    """
    algorithm = (
        request.get("context", {}).get("checksum", {}).get("request_algorithm")
    )
    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
        return

    location = algorithm["in"]
    if location == "header":
        _apply_request_header_checksum(request)
        return
    if location == "trailer":
        _apply_request_trailer_checksum(request)
        return

    raise FlexibleChecksumError(
        error_msg="Unknown checksum variant: %s" % location
    )


def _apply_request_trailer_checksum(request):
    """Prepare ``request`` to stream its body with a checksum trailer.

    Sets the chunked / aws-chunked / trailer headers and replaces
    ``request["body"]`` with an ``AioAwsChunkedWrapper`` that computes the
    checksum while the body is read. Nothing is changed when the checksum
    header named by the algorithm is already present in the request headers.

    :param request: request dict with ``"headers"``, ``"body"`` and a
        ``context["checksum"]["request_algorithm"]`` entry providing the
        ``"name"`` and ``"algorithm"`` keys; modified in place.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    # aiohttp wants chunking requested via a parameter rather than a header;
    # the HTTP session's send() pops this header again and passes
    # chunked=True, so it only acts as a marker up to that point.
    headers["Transfer-Encoding"] = "chunked"
    if "Content-Encoding" in headers:
        # Preserve an encoding the caller already declared (e.g. gzip) and
        # add aws-chunked to it instead of overwriting it.
        headers["Content-Encoding"] += ",aws-chunked"
    else:
        headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

    if isinstance(body, (bytes, bytearray)):
        # The wrapper reads from a file-like object, so wrap in-memory data.
        body = io.BytesIO(body)

    request["body"] = AioAwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )
17 changes: 12 additions & 5 deletions aiobotocore/httpsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
parse_url,
urlparse,
)
from multidict import MultiDict
from multidict import CIMultiDict

import aiobotocore.awsrequest
from aiobotocore._endpoint_helpers import _IOBaseWrapper, _text
Expand Down Expand Up @@ -188,20 +188,27 @@ async def send(self, request):
host = urlparse(request.url).hostname
proxy_headers['host'] = host

# https://github.com/boto/botocore/issues/1255
headers['Accept-Encoding'] = 'identity'

headers_ = MultiDict(
headers_ = CIMultiDict(
(z[0], _text(z[1], encoding='utf-8')) for z in headers.items()
)

# https://github.com/boto/botocore/issues/1255
headers_['Accept-Encoding'] = 'identity'

chunked = None
if headers_.get('Transfer-Encoding', '').lower() == 'chunked':
# aiohttp wants chunking as a param, and not a header
headers_.pop('Transfer-Encoding', '')
chunked = True

if isinstance(data, io.IOBase):
data = _IOBaseWrapper(data)

url = URL(url, encoded=True)
response = await self._session.request(
request.method,
url=url,
chunked=chunked,
headers=headers_,
data=data,
proxy=proxy_url,
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def signature_version():
return 's3'


@pytest.fixture
def server_scheme():
    """Default URL scheme for the mock servers; tests may parametrize it."""
    scheme = 'http'
    return scheme


@pytest.fixture
def s3_verify():
    """Default ``s3_verify`` value (``None``); tests may parametrize it.

    NOTE(review): presumably forwarded as the S3 client's ``verify``
    option -- confirm against the client fixtures.
    """
    verify = None
    return verify
Expand Down
46 changes: 24 additions & 22 deletions tests/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,66 +97,68 @@ async def _wait_until_up(self):


@pytest.fixture
async def s3_server():
async with MotoService('s3') as svc:
async def s3_server(server_scheme):
async with MotoService('s3', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def dynamodb2_server():
async with MotoService('dynamodb') as svc:
async def dynamodb2_server(server_scheme):
async with MotoService('dynamodb', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def cloudformation_server():
async with MotoService('cloudformation') as svc:
async def cloudformation_server(server_scheme):
async with MotoService(
'cloudformation', ssl=server_scheme == 'https'
) as svc:
yield svc.endpoint_url


@pytest.fixture
async def sns_server():
async with MotoService('sns') as svc:
async def sns_server(server_scheme):
async with MotoService('sns', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def sqs_server():
async with MotoService('sqs') as svc:
async def sqs_server(server_scheme):
async with MotoService('sqs', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def batch_server():
async with MotoService('batch') as svc:
async def batch_server(server_scheme):
async with MotoService('batch', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def lambda_server():
async with MotoService('lambda') as svc:
async def lambda_server(server_scheme):
async with MotoService('lambda', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def iam_server():
async with MotoService('iam') as svc:
async def iam_server(server_scheme):
async with MotoService('iam', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def rds_server():
async with MotoService('rds') as svc:
async def rds_server(server_scheme):
async with MotoService('rds', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def ec2_server():
async with MotoService('ec2') as svc:
async def ec2_server(server_scheme):
async with MotoService('ec2', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url


@pytest.fixture
async def kinesis_server():
async with MotoService('kinesis') as svc:
async def kinesis_server(server_scheme):
async with MotoService('kinesis', ssl=server_scheme == 'https') as svc:
yield svc.endpoint_url
18 changes: 14 additions & 4 deletions tests/moto_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MotoService:

_services = dict() # {name: instance}

def __init__(self, service_name: str, port: int = None):
def __init__(self, service_name: str, port: int = None, ssl: bool = False):
self._service_name = service_name

if port:
Expand All @@ -49,10 +49,14 @@ def __init__(self, service_name: str, port: int = None):
self._refcount = None
self._ip_address = host
self._server = None
self._ssl_ctx = (
werkzeug.serving.generate_adhoc_ssl_context() if ssl else None
)
self._schema = 'http' if not self._ssl_ctx else 'https'

@property
def endpoint_url(self):
return f'http://{self._ip_address}:{self._port}'
return f'{self._schema}://{self._ip_address}:{self._port}'

def __call__(self, func):
async def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -100,7 +104,11 @@ def _server_entry(self):
self._socket = None

self._server = werkzeug.serving.make_server(
self._ip_address, self._port, self._main_app, True
self._ip_address,
self._port,
self._main_app,
True,
ssl_context=self._ssl_ctx,
)
self._server.serve_forever()

Expand All @@ -118,7 +126,9 @@ async def _start(self):
try:
# we need to bypass the proxies due to monkeypatches
async with session.get(
self.endpoint_url + '/static', timeout=_CONNECT_TIMEOUT
self.endpoint_url + '/static',
timeout=_CONNECT_TIMEOUT,
verify_ssl=False,
):
pass
break
Expand Down
24 changes: 24 additions & 0 deletions tests/test_basic_s3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import base64
import hashlib
from collections import defaultdict

import aioitertools
Expand Down Expand Up @@ -633,3 +635,25 @@ async def test_head_object_keys(s3_client, create_object, bucket_name):
'ContentLength',
'VersionId',
}


@pytest.mark.xfail(
    reason="moto does not yet support Checksum: https://github.com/spulec/moto/issues/5719"
)
@pytest.mark.parametrize('server_scheme', ['https'])
@pytest.mark.parametrize('s3_verify', [False])
@pytest.mark.moto
@pytest.mark.asyncio
async def test_put_object_sha256(s3_client, bucket_name):
    """Upload an object with ``ChecksumAlgorithm='SHA256'`` and verify that
    the checksum reported in the response matches a locally computed digest.

    Runs against an HTTPS mock server with ``s3_verify`` set to ``False``.
    NOTE(review): HTTPS is presumably what makes botocore choose the
    streaming trailer checksum path -- confirm against botocore.
    """
    data = b'test1234'
    # Keep the digest as raw bytes: base64.b64decode() below returns bytes,
    # so a hex string would never compare equal to it.
    digest = hashlib.sha256(data).digest()

    resp = await s3_client.put_object(
        Bucket=bucket_name,
        Key='foobarbaz',
        Body=data,
        ChecksumAlgorithm='SHA256',
    )
    sha256_trailer_checksum = base64.b64decode(resp['ChecksumSHA256'])

    assert digest == sha256_trailer_checksum
16 changes: 15 additions & 1 deletion tests/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@
parse_get_bucket_location,
)
from botocore.hooks import EventAliaser, HierarchicalEmitter
from botocore.httpchecksum import _handle_bytes_response, handle_checksum_body
from botocore.httpchecksum import (
AwsChunkedWrapper,
_apply_request_trailer_checksum,
_handle_bytes_response,
apply_request_checksum,
handle_checksum_body,
)
from botocore.httpsession import URLLib3Session
from botocore.paginate import PageIterator, ResultKeyIterator
from botocore.parsers import (
Expand Down Expand Up @@ -557,6 +563,14 @@
# httpchecksum.py
handle_checksum_body: {'4b9aeef18d816563624c66c57126d1ffa6fe1993'},
_handle_bytes_response: {'0761c4590c6addbe8c674e40fca9f7dd375a184b'},
AwsChunkedWrapper._make_chunk: {
'097361692f0fd6c863a17dd695739629982ef7e4'
},
AwsChunkedWrapper.__iter__: {'261e26d1061655555fe3dcb2689d963e43f80fb0'},
apply_request_checksum: {'bcc044f0655f30769994efab72b29e76d73f7e39'},
_apply_request_trailer_checksum: {
'55c36eaf4701a379fcdbd78d0b7a831e5023a76e'
},
# retryhandler.py
retryhandler.create_retry_handler: {
'8fee36ed89d789194585f56b8dd4f525985a5811'
Expand Down

0 comments on commit 8d9d71a

Please sign in to comment.