Skip to content

Commit

Permalink
Preserve MultipartWriter parts headers on write (#3475)
Browse files Browse the repository at this point in the history
* Preserve MultipartWriter parts headers on write

This fixes #3035 

* Mark case when payload has no headers as unreachable with FIXME

This case is actually impossible because all the payload instances will
have headers defined with at least `Content-Type` definition.

While, it's theoretically possible to create `Payload` without headers
definition and Multipart format itself allows such parts, in multipart
module we already have set of assertions which wouldn't make this
possible.

Since `_binary_headers` is private property with unknown fate and
created just to not copy-paste headers serialization logic twice and
used exactly within `multipart` module, we're safe to ignore this
branch.

Proper fix would be refactoring the way how headers and their fragments
will get handled by `Payload` instances, but this quite a work out of
scope of current bugfix and will be addressed in upcoming PR.
  • Loading branch information
kxepal authored and asvetlov committed Jan 3, 2019
1 parent c77c058 commit 41274ea
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGES/3035.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve MultipartWriter parts headers on write.
17 changes: 6 additions & 11 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def _maybe_release_last_part(self) -> None:
self._last_part = None


_Part = Tuple[Payload, 'MultiMapping[str]', str, str]
_Part = Tuple[Payload, str, str]


class MultipartWriter(Payload):
Expand Down Expand Up @@ -812,12 +812,7 @@ def append_payload(self, payload: Payload) -> Payload:
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)

# render headers
headers = ''.join(
[k + ': ' + v + '\r\n' for k, v in payload.headers.items()]
).encode('utf-8') + b'\r\n'

self._parts.append((payload, headers, encoding, te_encoding)) # type: ignore # noqa
self._parts.append((payload, encoding, te_encoding)) # type: ignore
return payload

def append_json(
Expand Down Expand Up @@ -858,13 +853,13 @@ def size(self) -> Optional[int]:
return 0

total = 0
for part, headers, encoding, te_encoding in self._parts:
for part, encoding, te_encoding in self._parts:
if encoding or te_encoding or part.size is None:
return None

total += int(
2 + len(self._boundary) + 2 + # b'--'+self._boundary+b'\r\n'
part.size + len(headers) +
part.size + len(part._binary_headers) +
2 # b'\r\n'
)

Expand All @@ -877,9 +872,9 @@ async def write(self, writer: Any,
if not self._parts:
return

for part, headers, encoding, te_encoding in self._parts:
for part, encoding, te_encoding in self._parts:
await writer.write(b'--' + self._boundary + b'\r\n')
await writer.write(headers)
await writer.write(part._binary_headers)

if encoding or te_encoding:
w = MultipartPayloadWriter(writer)
Expand Down
9 changes: 9 additions & 0 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ def headers(self) -> Optional[_CIMultiDict]:
"""Custom item headers"""
return self._headers

@property
def _binary_headers(self) -> bytes:
if self.headers is None:
# FIXME: This case actually is unreachable.
return b'' # pragma: no cover
return ''.join(
[k + ': ' + v + '\r\n' for k, v in self.headers.items()]
).encode('utf-8') + b'\r\n'

@property
def encoding(self) -> Optional[str]:
"""Payload encoding"""
Expand Down
22 changes: 19 additions & 3 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,6 @@ def test_append_multipart(self, writer) -> None:
part = writer._parts[0][0]
assert part.headers[CONTENT_TYPE] == 'test/passed'

async def test_write(self, writer, stream) -> None:
await writer.write(stream)

def test_with(self) -> None:
with aiohttp.MultipartWriter(boundary=':') as writer:
writer.append('foo')
Expand All @@ -1033,6 +1030,25 @@ def test_append_none_not_allowed(self) -> None:
with aiohttp.MultipartWriter(boundary=':') as writer:
writer.append(None)

async def test_write_preserves_content_disposition(
self, buf, stream
) -> None:
with aiohttp.MultipartWriter(boundary=':') as writer:
part = writer.append(b'foo', headers={CONTENT_TYPE: 'test/passed'})
part.set_content_disposition('form-data', filename='bug')
await writer.write(stream)

headers, message = bytes(buf).split(b'\r\n\r\n', 1)

assert headers == (
b'--:\r\n'
b'Content-Type: test/passed\r\n'
b'Content-Length: 3\r\n'
b'Content-Disposition:'
b' form-data; filename="bug"; filename*=utf-8\'\'bug'
)
assert message == b'foo\r\n--:--\r\n'


async def test_async_for_reader() -> None:
data = [
Expand Down

0 comments on commit 41274ea

Please sign in to comment.