Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of multipart/form-data (#8280) #8301

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8280.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``multipart/form-data`` compliance with :rfc:`7578` -- by :user:`Dreamsorcerer`.
2 changes: 2 additions & 0 deletions CHANGES/8280.deprecation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Deprecated ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field()
<aiohttp.FormData.add_field>` -- by :user:`Dreamsorcerer`.
12 changes: 11 additions & 1 deletion aiohttp/formdata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode

Expand Down Expand Up @@ -53,7 +54,12 @@
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)

Check warning on line 62 in aiohttp/formdata.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/formdata.py#L62

Added line #L62 was not covered by tests
filename = name

type_options: MultiDict[str] = MultiDict({"name": name})
Expand Down Expand Up @@ -81,7 +87,11 @@
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True

self._fields.append((type_options, headers, value))
Expand Down
121 changes: 80 additions & 41 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,22 @@
chunk_size = 8192

def __init__(
self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
self,
boundary: bytes,
headers: "CIMultiDictProxy[str]",
content: StreamReader,
*,
subtype: str = "mixed",
default_charset: Optional[str] = None,
) -> None:
self.headers = headers
self._boundary = boundary
self._content = content
self._default_charset = default_charset
self._at_eof = False
length = self.headers.get(CONTENT_LENGTH, None)
self._is_form_data = subtype == "form-data"
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
self._length = int(length) if length is not None else None
self._read_bytes = 0
self._unread: Deque[bytes] = deque()
Expand Down Expand Up @@ -329,6 +338,8 @@
assert self._length is not None, "Content-Length required for chunked read"
chunk_size = min(size, self._length - self._read_bytes)
chunk = await self._content.read(chunk_size)
if self._content.at_eof():
self._at_eof = True

Check warning on line 342 in aiohttp/multipart.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/multipart.py#L342

Added line #L342 was not covered by tests
return chunk

async def _read_chunk_from_stream(self, size: int) -> bytes:
Expand Down Expand Up @@ -449,7 +460,8 @@
"""
if CONTENT_TRANSFER_ENCODING in self.headers:
data = self._decode_content_transfer(data)
if CONTENT_ENCODING in self.headers:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
if not self._is_form_data and CONTENT_ENCODING in self.headers:
return self._decode_content(data)
return data

Expand Down Expand Up @@ -483,7 +495,7 @@
"""Returns charset parameter from Content-Type header or default."""
ctype = self.headers.get(CONTENT_TYPE, "")
mimetype = parse_mimetype(ctype)
return mimetype.parameters.get("charset", default)
return mimetype.parameters.get("charset", self._default_charset or default)

@reify
def name(self) -> Optional[str]:
Expand Down Expand Up @@ -538,9 +550,17 @@
part_reader_cls = BodyPartReader

def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
assert self._mimetype.type == "multipart", "multipart/* content type expected"
if "boundary" not in self._mimetype.parameters:
raise ValueError(

Check warning on line 556 in aiohttp/multipart.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/multipart.py#L556

Added line #L556 was not covered by tests
"boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
)

self.headers = headers
self._boundary = ("--" + self._get_boundary()).encode()
self._content = content
self._default_charset: Optional[str] = None
self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
self._at_eof = False
self._at_bof = True
Expand Down Expand Up @@ -592,7 +612,24 @@
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return None
self._last_part = await self.fetch_next_part()

part = await self.fetch_next_part()
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
if (
self._last_part is None
and self._mimetype.subtype == "form-data"
and isinstance(part, BodyPartReader)
):
_, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
if params.get("name") == "_charset_":
# Longest encoding in https://encoding.spec.whatwg.org/encodings.json
# is 19 characters, so 32 should be more than enough for any valid encoding.
charset = await part.read_chunk(32)
if len(charset) > 31:
raise RuntimeError("Invalid default charset")
self._default_charset = charset.strip().decode()
part = await self.fetch_next_part()
self._last_part = part
return self._last_part

async def release(self) -> None:
Expand Down Expand Up @@ -628,19 +665,16 @@
return type(self)(headers, self._content)
return self.multipart_reader_cls(headers, self._content)
else:
return self.part_reader_cls(self._boundary, headers, self._content)

def _get_boundary(self) -> str:
mimetype = parse_mimetype(self.headers[CONTENT_TYPE])

assert mimetype.type == "multipart", "multipart/* content type expected"

if "boundary" not in mimetype.parameters:
raise ValueError(
"boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
return self.part_reader_cls(
self._boundary,
headers,
self._content,
subtype=self._mimetype.subtype,
default_charset=self._default_charset,
)

boundary = mimetype.parameters["boundary"]
def _get_boundary(self) -> str:
boundary = self._mimetype.parameters["boundary"]
if len(boundary) > 70:
raise ValueError("boundary %r is too long (70 chars max)" % boundary)

Expand Down Expand Up @@ -731,6 +765,7 @@
super().__init__(None, content_type=ctype)

self._parts: List[_Part] = []
self._is_form_data = subtype == "form-data"

def __enter__(self) -> "MultipartWriter":
return self
Expand Down Expand Up @@ -808,32 +843,36 @@

def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# compression
encoding: Optional[str] = payload.headers.get(
CONTENT_ENCODING,
"",
).lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None

# te encoding
te_encoding: Optional[str] = payload.headers.get(
CONTENT_TRANSFER_ENCODING,
"",
).lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(
"unknown content transfer encoding: {}" "".format(te_encoding)
encoding: Optional[str] = None
te_encoding: Optional[str] = None
if self._is_form_data:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
assert CONTENT_DISPOSITION in payload.headers
assert "name=" in payload.headers[CONTENT_DISPOSITION]
assert (
not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
& payload.headers.keys()
)
if te_encoding == "binary":
te_encoding = None

# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)
else:
# compression
encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None

# te encoding
te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
if te_encoding == "binary":
te_encoding = None

# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)

self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type]
return payload
Expand Down
44 changes: 1 addition & 43 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,48 +1387,6 @@ async def handler(request):
resp.close()


async def test_POST_DATA_with_context_transfer_encoding(aiohttp_client) -> None:
async def handler(request):
data = await request.post()
assert data["name"] == "text"
return web.Response(text=data["name"])

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

form = aiohttp.FormData()
form.add_field("name", "text", content_transfer_encoding="base64")

resp = await client.post("/", data=form)
assert 200 == resp.status
content = await resp.text()
assert content == "text"
resp.close()


async def test_POST_DATA_with_content_type_context_transfer_encoding(aiohttp_client):
async def handler(request):
data = await request.post()
assert data["name"] == "text"
return web.Response(body=data["name"])

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

form = aiohttp.FormData()
form.add_field(
"name", "text", content_type="text/plain", content_transfer_encoding="base64"
)

resp = await client.post("/", data=form)
assert 200 == resp.status
content = await resp.text()
assert content == "text"
resp.close()


async def test_POST_MultiDict(aiohttp_client) -> None:
async def handler(request):
data = await request.post()
Expand Down Expand Up @@ -1480,7 +1438,7 @@ async def handler(request):

with fname.open("rb") as f:
async with client.post(
"/", data={"some": f, "test": b"data"}, chunked=True
"/", data={"some": f, "test": io.BytesIO(b"data")}, chunked=True
) as resp:
assert 200 == resp.status

Expand Down
Loading
Loading