Skip to content

Commit

Permalink
Allow lax response parsing on Py parser (#7663)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer committed Oct 6, 2023
1 parent 454092d commit bd5f924
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 86 deletions.
1 change: 1 addition & 0 deletions CHANGES/7663.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated Python parser to comply with latest HTTP specs and allow lax response parsing -- by :user:`Dreamsorcerer`
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py)

# _find_headers generator creates _headers.pyi as well
aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c
cython -3 -o $@ $< -I aiohttp
cython -3 -o $@ $< -I aiohttp -Werror

vendor/llhttp/node_modules: vendor/llhttp/package.json
cd vendor/llhttp; npm install
Expand Down
86 changes: 66 additions & 20 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
ClassVar,
Final,
Generic,
List,
Literal,
NamedTuple,
Optional,
Pattern,
Expand All @@ -24,7 +27,7 @@
from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
Expand All @@ -48,6 +51,8 @@
"RawResponseMessage",
)

_SEP = Literal[b"\r\n", b"\n"]

ASCIISET: Final[Set[str]] = set(string.printable)

# See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview
Expand All @@ -60,6 +65,7 @@
METHRE: Final[Pattern[str]] = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d).(\d)")
HDRRE: Final[Pattern[bytes]] = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\"\\]")
HEXDIGIT = re.compile(rb"[0-9a-fA-F]+")


class RawRequestMessage(NamedTuple):
Expand Down Expand Up @@ -206,6 +212,8 @@ def parse_headers(


class HttpParser(abc.ABC, Generic[_MsgT]):
lax: ClassVar[bool] = False

def __init__(
self,
protocol: BaseProtocol,
Expand Down Expand Up @@ -266,7 +274,7 @@ def feed_eof(self) -> Optional[_MsgT]:
def feed_data(
self,
data: bytes,
SEP: bytes = b"\r\n",
SEP: _SEP = b"\r\n",
EMPTY: bytes = b"",
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
METH_CONNECT: str = hdrs.METH_CONNECT,
Expand All @@ -288,13 +296,16 @@ def feed_data(
pos = data.find(SEP, start_pos)
# consume \r\n
if pos == start_pos and not self._lines:
start_pos = pos + 2
start_pos = pos + len(SEP)
continue

if pos >= start_pos:
# line found
self._lines.append(data[start_pos:pos])
start_pos = pos + 2
line = data[start_pos:pos]
if SEP == b"\n": # For lax response parsing
line = line.rstrip(b"\r")
self._lines.append(line)
start_pos = pos + len(SEP)

# \r\n\r\n found
if self._lines[-1] == EMPTY:
Expand All @@ -311,7 +322,7 @@ def get_content_length() -> Optional[int]:

# Shouldn't allow +/- or other number formats.
# https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2
if not length_hdr.strip(" \t").isdigit():
if not length_hdr.strip(" \t").isdecimal():
raise InvalidHeader(CONTENT_LENGTH)

return int(length_hdr)
Expand Down Expand Up @@ -348,6 +359,7 @@ def get_content_length() -> Optional[int]:
readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
Expand All @@ -366,6 +378,7 @@ def get_content_length() -> Optional[int]:
compression=msg.compression,
readall=True,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
else:
if (
Expand All @@ -389,6 +402,7 @@ def get_content_length() -> Optional[int]:
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
Expand All @@ -411,7 +425,7 @@ def get_content_length() -> Optional[int]:
assert not self._lines
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:])
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as exc:
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
Expand Down Expand Up @@ -456,12 +470,21 @@ def parse_headers(

# https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6
# https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf
singletons = (hdrs.CONTENT_LENGTH, hdrs.CONTENT_LOCATION, hdrs.CONTENT_RANGE,
hdrs.CONTENT_TYPE, hdrs.ETAG, hdrs.HOST, hdrs.MAX_FORWARDS,
hdrs.SERVER, hdrs.TRANSFER_ENCODING, hdrs.USER_AGENT)
singletons = (
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TYPE,
hdrs.ETAG,
hdrs.HOST,
hdrs.MAX_FORWARDS,
hdrs.SERVER,
hdrs.TRANSFER_ENCODING,
hdrs.USER_AGENT,
)
bad_hdr = next((h for h in singletons if len(headers.getall(h, ())) > 1), None)
if bad_hdr is not None:
raise BadHttpMessage("Duplicate '{}' header found.".format(bad_hdr))
raise BadHttpMessage(f"Duplicate '{bad_hdr}' header found.")

# keep-alive
conn = headers.get(hdrs.CONNECTION)
Expand Down Expand Up @@ -597,6 +620,20 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
Returns RawResponseMessage.
"""

# Lax mode should only be enabled on response parser.
lax = not DEBUG

def feed_data(
    self,
    data: bytes,
    SEP: Optional[_SEP] = None,
    *args: Any,
    **kwargs: Any,
) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]:
    """Feed response bytes, defaulting to a DEBUG-dependent separator.

    When ``SEP`` is not supplied, debug builds keep the strict ``\\r\\n``
    line separator, while regular builds use a bare ``\\n`` so the lax
    response parsing path can tolerate servers that omit the CR.
    """
    sep: _SEP = SEP if SEP is not None else (b"\r\n" if DEBUG else b"\n")
    return super().feed_data(data, sep, *args, **kwargs)

def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
line = lines[0].decode("utf-8", "surrogateescape")
try:
Expand All @@ -621,7 +658,7 @@ def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

# The status code is a three-digit number
if len(status) != 3 or not status.isdigit():
if len(status) != 3 or not status.isdecimal():
raise BadStatusLine(line)
status_i = int(status)

Expand Down Expand Up @@ -663,13 +700,15 @@ def __init__(
readall: bool = False,
response_with_body: bool = True,
auto_decompress: bool = True,
lax: bool = False,
) -> None:
self._length = 0
self._type = ParseState.PARSE_NONE
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b""
self._auto_decompress = auto_decompress
self._lax = lax
self.done = False

# payload decompression wrapper
Expand Down Expand Up @@ -721,7 +760,7 @@ def feed_eof(self) -> None:
)

def feed_data(
self, chunk: bytes, SEP: bytes = b"\r\n", CHUNK_EXT: bytes = b";"
self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";"
) -> Tuple[bool, bytes]:
# Read specified amount of bytes
if self._type == ParseState.PARSE_LENGTH:
Expand Down Expand Up @@ -757,17 +796,22 @@ def feed_data(
else:
size_b = chunk[:pos]

if not size_b.isdigit():
if self._lax: # Allow whitespace in lax mode.
size_b = size_b.strip()

if not re.fullmatch(HEXDIGIT, size_b):
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
raise exc
size = int(bytes(size_b), 16)

chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
if size == 0: # eof marker
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
else:
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
self._chunk_size = size
Expand All @@ -789,13 +833,15 @@ def feed_data(
self._chunk_size = 0
self.payload.feed_data(chunk[:required], required)
chunk = chunk[required:]
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
self.payload.end_http_chunk_receiving()

# toss the CRLF at the end of the chunk
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
if chunk[:2] == SEP:
chunk = chunk[2:]
if chunk[: len(SEP)] == SEP:
chunk = chunk[len(SEP) :]
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
else:
self._chunk_tail = chunk
Expand All @@ -805,11 +851,11 @@ def feed_data(
# we should get another \r\n otherwise
# trailers needs to be skiped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
head = chunk[:2]
head = chunk[: len(SEP)]
if head == SEP:
# end of stream
self.payload.feed_eof()
return True, chunk[2:]
return True, chunk[len(SEP) :]
# Both CR and LF, or only LF may not be received yet. It is
# expected that CRLF or LF will be shown at the very first
# byte next time, otherwise trailers should come. The last
Expand All @@ -827,7 +873,7 @@ def feed_data(
if self._chunk == ChunkState.PARSE_TRAILERS:
pos = chunk.find(SEP)
if pos >= 0:
chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
self._chunk_tail = chunk
Expand Down

0 comments on commit bd5f924

Please sign in to comment.