Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions src/mock_vws/_query_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import io
import uuid
from email.message import EmailMessage
from typing import IO, TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any
from zoneinfo import ZoneInfo

from werkzeug.formparser import MultiPartParser
Expand All @@ -20,33 +20,10 @@
from mock_vws._mock_common import json_dump

if TYPE_CHECKING:
from werkzeug.datastructures import FileStorage, MultiDict

from mock_vws.database import VuforiaDatabase
from mock_vws.image_matchers import ImageMatcher


class TypedMultiPartParser(MultiPartParser):
"""
A MultiPartParser which returns types for fields.

This is a workaround for https://github.com/pallets/werkzeug/pull/2841.
"""

def parse(
self,
stream: IO[bytes],
boundary: bytes,
content_length: int | None,
) -> tuple[MultiDict[str, str], MultiDict[str, FileStorage]]:
# Once this Pyright issue is fixed, we can remove this whole class.
return super().parse( # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
stream=stream,
boundary=boundary,
content_length=content_length,
)


def get_query_match_response_text(
request_headers: dict[str, str],
request_body: bytes,
Expand All @@ -73,7 +50,7 @@ def get_query_match_response_text(
boundary = email_message.get_boundary()
assert isinstance(boundary, str)

parser = TypedMultiPartParser()
parser = MultiPartParser()
fields, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down
5 changes: 3 additions & 2 deletions src/mock_vws/_query_validators/fields_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import logging
from email.message import EmailMessage

from mock_vws._query_tools import TypedMultiPartParser
from werkzeug.formparser import MultiPartParser

from mock_vws._query_validators.exceptions import UnknownParameters

_LOGGER = logging.getLogger(__name__)
Expand All @@ -30,7 +31,7 @@ def validate_extra_fields(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
fields, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down
12 changes: 6 additions & 6 deletions src/mock_vws/_query_validators/image_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from email.message import EmailMessage

from PIL import Image
from werkzeug.formparser import MultiPartParser

from mock_vws._query_tools import TypedMultiPartParser
from mock_vws._query_validators.exceptions import (
BadImage,
ImageNotGiven,
Expand Down Expand Up @@ -36,7 +36,7 @@ def validate_image_field_given(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
_, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down Expand Up @@ -67,7 +67,7 @@ def validate_image_file_size(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
_, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down Expand Up @@ -108,7 +108,7 @@ def validate_image_dimensions(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
_, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down Expand Up @@ -145,7 +145,7 @@ def validate_image_format(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
_, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down Expand Up @@ -179,7 +179,7 @@ def validate_image_is_image(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
_, files = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import logging
from email.message import EmailMessage

from mock_vws._query_tools import TypedMultiPartParser
from werkzeug.formparser import MultiPartParser

from mock_vws._query_validators.exceptions import InvalidIncludeTargetData

_LOGGER = logging.getLogger(__name__)
Expand All @@ -32,7 +33,7 @@ def validate_include_target_data(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
fields, _ = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down
5 changes: 3 additions & 2 deletions src/mock_vws/_query_validators/num_results_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import logging
from email.message import EmailMessage

from mock_vws._query_tools import TypedMultiPartParser
from werkzeug.formparser import MultiPartParser

from mock_vws._query_validators.exceptions import (
InvalidMaxNumResults,
MaxNumResultsOutOfRange,
Expand Down Expand Up @@ -36,7 +37,7 @@ def validate_max_num_results(
email_message["Content-Type"] = request_headers["Content-Type"]
boundary = email_message.get_boundary()
assert isinstance(boundary, str)
parser = TypedMultiPartParser()
parser = MultiPartParser()
fields, _ = parser.parse(
stream=io.BytesIO(request_body),
boundary=boundary.encode("utf-8"),
Expand Down