diff --git a/src/mock_vws/_query_tools.py b/src/mock_vws/_query_tools.py index a177b72ae..e0aa6fef9 100644 --- a/src/mock_vws/_query_tools.py +++ b/src/mock_vws/_query_tools.py @@ -9,7 +9,7 @@ import io import uuid from email.message import EmailMessage -from typing import IO, TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any from zoneinfo import ZoneInfo from werkzeug.formparser import MultiPartParser @@ -20,33 +20,10 @@ from mock_vws._mock_common import json_dump if TYPE_CHECKING: - from werkzeug.datastructures import FileStorage, MultiDict - from mock_vws.database import VuforiaDatabase from mock_vws.image_matchers import ImageMatcher -class TypedMultiPartParser(MultiPartParser): - """ - A MultiPartParser which returns types for fields. - - This is a workaround for https://github.com/pallets/werkzeug/pull/2841. - """ - - def parse( - self, - stream: IO[bytes], - boundary: bytes, - content_length: int | None, - ) -> tuple[MultiDict[str, str], MultiDict[str, FileStorage]]: - # Once this Pyright issue is fixed, we can remove this whole class. - return super().parse( # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - stream=stream, - boundary=boundary, - content_length=content_length, - ) - - def get_query_match_response_text( request_headers: dict[str, str], request_body: bytes, @@ -73,7 +50,7 @@ def get_query_match_response_text( boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() fields, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), diff --git a/src/mock_vws/_query_validators/fields_validators.py b/src/mock_vws/_query_validators/fields_validators.py index 2f0fcb3ef..42e964f36 100644 --- a/src/mock_vws/_query_validators/fields_validators.py +++ b/src/mock_vws/_query_validators/fields_validators.py @@ -6,7 +6,8 @@ import logging from email.message import EmailMessage -from mock_vws._query_tools import TypedMultiPartParser +from werkzeug.formparser import MultiPartParser + from mock_vws._query_validators.exceptions import UnknownParameters _LOGGER = logging.getLogger(__name__) @@ -30,7 +31,7 @@ def validate_extra_fields( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() fields, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), diff --git a/src/mock_vws/_query_validators/image_validators.py b/src/mock_vws/_query_validators/image_validators.py index 250bbc96b..ca1ac76e0 100644 --- a/src/mock_vws/_query_validators/image_validators.py +++ b/src/mock_vws/_query_validators/image_validators.py @@ -7,8 +7,8 @@ from email.message import EmailMessage from PIL import Image +from werkzeug.formparser import MultiPartParser -from mock_vws._query_tools import TypedMultiPartParser from mock_vws._query_validators.exceptions import ( BadImage, ImageNotGiven, @@ -36,7 +36,7 @@ def validate_image_field_given( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() _, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), @@ -67,7 +67,7 @@ def validate_image_file_size( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() _, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), @@ -108,7 +108,7 @@ def validate_image_dimensions( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() _, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), @@ -145,7 +145,7 @@ def validate_image_format( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() _, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), @@ -179,7 +179,7 @@ def validate_image_is_image( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() _, files = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), diff --git a/src/mock_vws/_query_validators/include_target_data_validators.py b/src/mock_vws/_query_validators/include_target_data_validators.py index a539d5a22..176c04ae1 100644 --- a/src/mock_vws/_query_validators/include_target_data_validators.py +++ b/src/mock_vws/_query_validators/include_target_data_validators.py @@ -6,7 +6,8 @@ import logging from email.message import EmailMessage -from mock_vws._query_tools import TypedMultiPartParser +from werkzeug.formparser import MultiPartParser + from mock_vws._query_validators.exceptions import InvalidIncludeTargetData _LOGGER = logging.getLogger(__name__) @@ -32,7 +33,7 @@ def validate_include_target_data( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() fields, _ = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"), diff --git a/src/mock_vws/_query_validators/num_results_validators.py b/src/mock_vws/_query_validators/num_results_validators.py index 37d4cdde9..512201fd9 100644 --- a/src/mock_vws/_query_validators/num_results_validators.py +++ b/src/mock_vws/_query_validators/num_results_validators.py @@ -6,7 +6,8 @@ import logging from email.message import EmailMessage -from mock_vws._query_tools import TypedMultiPartParser +from werkzeug.formparser import MultiPartParser + from mock_vws._query_validators.exceptions import ( InvalidMaxNumResults, MaxNumResultsOutOfRange, @@ -36,7 +37,7 @@ def validate_max_num_results( email_message["Content-Type"] = request_headers["Content-Type"] boundary = email_message.get_boundary() assert isinstance(boundary, str) - parser = TypedMultiPartParser() + parser = MultiPartParser() fields, _ = parser.parse( stream=io.BytesIO(request_body), boundary=boundary.encode("utf-8"),