diff --git a/.gitignore b/.gitignore index 698f68e..8154e92 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,9 @@ dist/ __pycache__/ *.egg-info *.pyc -test.py \ No newline at end of file +test.py + +venv/ +.venv/ +env/ +.env \ No newline at end of file diff --git a/README.md b/README.md index e36d732..b12f6ff 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,8 @@ def hello( password_expiry: Optional[int] = Json(5), is_admin: bool = Query(False), user_type: UserType = Json(alias="type"), - status: AccountStatus = Json() + status: AccountStatus = Json(), + permissions: dict[str, str] = Query(list_disable_query_csv=True) ): return "Hello World!" @@ -130,7 +131,7 @@ Type Hints allow for inline specification of the input type of a parameter. Some | `datetime.datetime` | Received as a `str` in ISO-8601 date-time format | Y | Y | Y | Y | N | | `datetime.date` | Received as a `str` in ISO-8601 full-date format | Y | Y | Y | Y | N | | `datetime.time` | Received as a `str` in ISO-8601 partial-time format | Y | Y | Y | Y | N | -| `dict` | For `Query` and `Form` inputs, users should pass the stringified JSON | N | Y | Y | Y | N | +| `dict` | For `Query` and `Form` inputs, users should pass the stringified JSON. For `Query`, you likely will need to use `list_disable_query_csv=True`. | N | Y | Y | Y | N | | `FileStorage` | | N | N | N | N | Y | | A subclass of `StrEnum` or `IntEnum`, or a subclass of `Enum` with `str` or `int` mixins prior to Python 3.11 | | Y | Y | Y | Y | N | | `uuid.UUID` | Received as a `str` with or without hyphens, case-insensitive | Y | Y | Y | Y | N | diff --git a/flask_parameter_validation/docs_blueprint.py b/flask_parameter_validation/docs_blueprint.py index 671e251..8712b29 100644 --- a/flask_parameter_validation/docs_blueprint.py +++ b/flask_parameter_validation/docs_blueprint.py @@ -1,9 +1,12 @@ +import sys from enum import Enum import flask from flask import Blueprint, current_app, jsonify - from flask_parameter_validation import ValidateParameters +if sys.version_info >= (3, 10): + from types import UnionType + docs_blueprint = Blueprint( "docs", __name__, url_prefix="/docs", template_folder="./templates" ) @@ -76,7 +79,10 @@ def get_arg_type_hint(fdocs, arg_name): """ arg_type = fdocs["argspec"].annotations[arg_name] def recursively_resolve_type_hint(type_to_resolve): - if hasattr(type_to_resolve, "__name__"): # In Python 3.9, Optional and Union do not have __name__ + if sys.version_info >= (3, 10) and isinstance(type_to_resolve, UnionType): + # support 3.10 style unions (e.g. str | int) + type_base_name = "Union" + elif hasattr(type_to_resolve, "__name__"): # In Python 3.9, Optional and Union do not have __name__ type_base_name = type_to_resolve.__name__ elif hasattr(type_to_resolve, "_name") and type_to_resolve._name is not None: # In Python 3.9, _name exists on list[whatever] and has a non-None value diff --git a/flask_parameter_validation/parameter_types/__init__.py b/flask_parameter_validation/parameter_types/__init__.py index 935aad7..d9910c3 100644 --- a/flask_parameter_validation/parameter_types/__init__.py +++ b/flask_parameter_validation/parameter_types/__init__.py @@ -4,7 +4,8 @@ from .query import Query from .route import Route from .multi_source import MultiSource +from .parameter import Parameter __all__ = [ - "File", "Form", "Json", "Query", "Route", "MultiSource" + "File", "Form", "Json", "Query", "Route", "MultiSource", "Parameter" ] diff --git a/flask_parameter_validation/parameter_validation.py b/flask_parameter_validation/parameter_validation.py index 66ba68e..3bed8b1 100644 --- a/flask_parameter_validation/parameter_validation.py +++ b/flask_parameter_validation/parameter_validation.py @@ -1,23 +1,29 @@ +import json +import sys import asyncio import functools import inspect import re import uuid from inspect import signature -from typing import Optional +from typing import Optional, Union, get_origin, get_args, Any import flask -from flask import request, Response +from flask import request from werkzeug.datastructures import ImmutableMultiDict from werkzeug.exceptions import BadRequest from .exceptions import (InvalidParameterTypeError, MissingInputError, ValidationError) -from .parameter_types import File, Form, Json, Query, Route +from .parameter_types import File, Form, Json, Query, Route, Parameter from .parameter_types.multi_source import MultiSource fn_list = dict() -list_type_hints = ["typing.List", "typing.Optional[typing.List", "list", "typing.Optional[list"] +# from 3.10 onwards, Unions written X | Y have the type UnionType +UNION_TYPES = [Union] +if sys.version_info >= (3, 10): + from types import UnionType + UNION_TYPES = [Union, UnionType] class ValidateParameters: @classmethod @@ -72,23 +78,22 @@ def nested_func_helper(**kwargs): except BadRequest: return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False} - # Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists), and for Query params, whether they should split strings by `,` - expected_list_params = {} + # Step 3 - For Query params, find which parameters should be split by commas + split_csv = {} default_list_disable_query_csv = flask.current_app.config.get("FPV_LIST_DISABLE_QUERY_CSV", False) for name, param in expected_inputs.items(): - if any([str(param.annotation).startswith(list_hint) for list_hint in list_type_hints]): - list_disable_query_csv = default_list_disable_query_csv - if param.default.list_disable_query_csv is not None: - list_disable_query_csv = param.default.list_disable_query_csv - expected_list_params[param.default.alias or name] = not list_disable_query_csv + list_disable_query_csv = default_list_disable_query_csv + if param.default.list_disable_query_csv is not None: + list_disable_query_csv = param.default.list_disable_query_csv + split_csv[param.default.alias or name] = not list_disable_query_csv # Step 4 - Convert request inputs to dicts request_inputs = { Route: kwargs.copy(), Json: json_input or {}, - Query: self._to_dict_with_lists(request.args, list(expected_list_params.keys()), list(expected_list_params.values())), - Form: self._to_dict_with_lists(request.form, list(expected_list_params.keys())), - File: self._to_dict_with_lists(request.files, list(expected_list_params.keys())), + Query: self._to_dict_with_lists(request.args, split_csv), + Form: self._to_dict_with_lists(request.form), + File: self._to_dict_with_lists(request.files), } # Step 5 - Validate each expected input @@ -129,90 +134,140 @@ def nested_func(**kwargs): return nested_func def _to_dict_with_lists( - self, multi_dict: ImmutableMultiDict, expected_lists: list[str], split_strings: Optional[list[bool]] = None + self, multi_dict: ImmutableMultiDict, split_csv: Optional[dict[str, bool]] = None ) -> dict: dict_with_lists = {} for key, values in multi_dict.lists(): - # Only create lists for keys that are expected to be lists - if key in expected_lists: - key_index = expected_lists.index(key) - list_values = [] - for value in values: - if value != "" or len(values) > 1: - if split_strings and split_strings[key_index]: - list_values.extend(value.split(",")) - else: - list_values.append(value) - dict_with_lists[key] = list_values - else: - # If only one value and not expected to be a list, don't use a list - dict_with_lists[key] = values[0] if len(values) == 1 else values + list_values = [] + for value in values: + if split_csv and key in split_csv and split_csv[key]: + list_values.extend(value.split(",")) + else: + list_values.append(value) + dict_with_lists[key] = list_values[0] if len(list_values) == 1 else list_values return dict_with_lists - def _generic_types_validation_helper(self, expected_name, expected_input_type, expected_input_type_str, user_input, source): + def _generic_types_validation_helper(self, + expected_name: str, + expected_input_type: type, + user_input: Any, + source: Parameter, + other_union_allowed_types: list[type] = []) -> tuple[Any, bool]: """ Perform recursive validation of generic types (Optional, Union, and List/list) + and convert input. If input is invalid, a fully converted input is not garunteed. + + :param expected_name: the name of the parameter we are checking against + :param expected_input_type: the type annotation of the parameter + :param user_input: the API user's input + :param source: the type of Parameter we are taking input from + :param other_union_allowed_types: the other types that are unioned at this level. + We check one type at a time, but the convert() method needs to know + what else the user_input is allowed to be to convert properly. + + :return: tuple of format (converted user_input, validation_success) """ - # In python3.7+, typing.Optional is used instead of typing.Union[..., None] - if expected_input_type_str.startswith("typing.Optional"): - sub_expected_input_types = expected_input_type - sub_expected_input_type_str = expected_input_type_str.replace("typing.Optional[", "typing.Union[None, ") - user_inputs, sub_expected_input_types = self._generic_types_validation_helper(expected_name, sub_expected_input_types, sub_expected_input_type_str, user_input, source) - elif expected_input_type_str.startswith("typing.Union"): - if type(expected_input_type) is tuple or type(expected_input_type) is list: - sub_expected_input_types = expected_input_type - else: - sub_expected_input_types = expected_input_type.__args__ - sub_expected_input_type_str = expected_input_type_str[expected_input_type_str.index("[") + 1:-1] - if type(user_input) is list: - user_inputs = user_input - else: - user_inputs = [user_input] - user_inputs, sub_expected_input_types = self._generic_types_validation_helper(expected_name, sub_expected_input_types, sub_expected_input_type_str, user_inputs, source) - # If typing.List in optional and user supplied valid list, convert remaining check only for list - for exp_type in sub_expected_input_types: - if any(str(exp_type).startswith(list_hint) for list_hint in list_type_hints): - if type(user_input) is list: - if hasattr(exp_type, "__args__"): - sub_expected_input_types = exp_type.__args__ - if len(sub_expected_input_types) == 1: - sub_expected_input_types = sub_expected_input_types[0] - sub_expected_input_type_str = str(sub_expected_input_types) - user_inputs = user_input - user_inputs, sub_expected_input_types = self._generic_types_validation_helper(expected_name, sub_expected_input_types, sub_expected_input_type_str, user_inputs, source) - # If list, expand inner typing items. Otherwise, convert to list to match anyway. - elif any(expected_input_type_str.startswith(list_hint) for list_hint in list_type_hints): - if hasattr(expected_input_type, "__args__"): - sub_expected_input_types = expected_input_type.__args__[0] + # union + if get_origin(expected_input_type) in UNION_TYPES: + # check for unions (Optional is just a Union with None) + sub_expected_input_types = expected_input_type.__args__ + # go through each type in the union and see if we get a match + for sub_expected_input_type in sub_expected_input_types: + sub_converted_input, sub_success = self._generic_types_validation_helper(expected_name, sub_expected_input_type, user_input, source, other_union_allowed_types=list(sub_expected_input_types)) + if sub_success: + return sub_converted_input, True + return user_input, False + + # list + elif get_origin(expected_input_type) is list or expected_input_type is list: + if type(user_input) is not list: + # check if we should try to work with strings + if type(source) is not Form and type(source) is not Query: + return user_input, False + # if using a source that supports multidict style lists, + # give singletons the benefit of the doubt. they could still count + # as single-element lists + if type(user_input) is str and len(user_input) > 0: + try: + user_input = json.loads(user_input) + # check for a stringified list e.g. '[1, 2]' + if type(user_input) is not list: + user_input = [user_input] + except ValueError: + user_input = [user_input] + else: + user_input = [user_input] + + # process + if len(get_args(expected_input_type)) == 0: + # expected type is just a bare list with no sub type + # we set to Any instead of returning True so that the input can get converted + sub_expected_input_type = Any else: - sub_expected_input_types = expected_input_type - sub_expected_input_type_str = expected_input_type_str[expected_input_type_str.index("[")+1:-1] - if type(user_input) is list: - user_inputs = user_input + sub_expected_input_type = get_args(expected_input_type)[0] + if len(user_input) == 1 and user_input[0] == "": + # treat arrays of a single empty string as an empty array to support the Query param &value= + return [], True + converted_list = [] + # go through and validate each item in the array + for inp in user_input: + sub_converted_input, sub_success = self._generic_types_validation_helper(expected_name, sub_expected_input_type, inp, source) + if not sub_success: + return user_input, False + converted_list.append(sub_converted_input) + return converted_list, True + + # dict + elif get_origin(expected_input_type) is dict or expected_input_type is dict: + # check for a stringified dict (like from Query or Form) + if type(user_input) is str and len(user_input) > 0: + try: + user_input = json.loads(user_input) + except ValueError: + return user_input, False + # check for a normal dict + if type(user_input) is not dict: + return user_input, False + + # process + if len(get_args(expected_input_type)) == 0: + # expected type is just a bare dict with no sub types + # we set to Any instead of returning True so that the input can get converted + key_expected_input_type = Any + val_expected_input_type = Any else: - user_inputs = [user_input] - user_inputs, sub_expected_input_types = self._generic_types_validation_helper(expected_name, sub_expected_input_types, sub_expected_input_type_str, user_inputs, source) + key_expected_input_type = get_args(expected_input_type)[0] + val_expected_input_type = get_args(expected_input_type)[1] + converted_dict = {} + # go through and validate each key and value in the dict + for key, val in user_input.items(): + key_converted_input, key_success = self._generic_types_validation_helper(expected_name, key_expected_input_type, key, source) + val_converted_input, val_success = self._generic_types_validation_helper(expected_name, val_expected_input_type, val, source) + if not key_success or not val_success: + return user_input, False + converted_dict[key_converted_input] = val_converted_input + return converted_dict, True + + # non-generics else: - if type(user_input) is list: - user_inputs = user_input - else: - user_inputs = [user_input] - if type(expected_input_type) is list or type(expected_input_type) is tuple: - sub_expected_input_types = expected_input_type - elif type(expected_input_type) is list and len(expected_input_type) > 0 and hasattr(expected_input_type[0], "__len__"): - sub_expected_input_types = expected_input_type[0] - elif expected_input_type is list and not hasattr(expected_input_type, "__args__"): - return [user_inputs], [expected_input_type] - else: - sub_expected_input_types = [expected_input_type] - for count, value in enumerate(user_inputs): - try: - user_inputs[count] = source.convert( - value, sub_expected_input_types - ) - except ValueError as e: - raise ValidationError(str(e), expected_name, expected_input_type) - return user_inputs, sub_expected_input_types + if expected_input_type is Any: + return user_input, True + + try: + # convert + user_input = source.convert( + # include any other allowed types for proper conversion + user_input, [expected_input_type] + other_union_allowed_types + ) + + if expected_input_type is Any: + # Any should always return true, no matter the input + return user_input, True + + # the actual "primative" type check + return user_input, type(user_input) is expected_input_type + except ValueError as e: + raise ValidationError(str(e), expected_name, expected_input_type) def validate(self, expected_input, all_request_inputs): """ @@ -227,14 +282,10 @@ def validate(self, expected_input, all_request_inputs): expected_name = expected_delivery_type.alias else: expected_name = expected_input.name - # Get input type as string to recognize typing objects, e.g. to convert typing.List to "typing.List" - # Note: We use this str() method, as typing API is too unreliable, see https://stackoverflow.com/a/52664522/7173479 - expected_input_type_str = str(expected_input.annotation) - # original_expected_input_type and expected_input_type_str will mutate throughout program, + # original_expected_input_type will mutate throughout program, # so we need to keep the original for error messages original_expected_input_type = expected_input.annotation - original_expected_input_type_str = expected_input_type_str # Expected delivery types can be a list if using MultiSource expected_delivery_types = [expected_delivery_type] @@ -273,46 +324,21 @@ def validate(self, expected_input, all_request_inputs): expected_name, source.__class__ ) - # Skip validation if typing.Any is given - if expected_input_type_str.startswith("typing.Any"): - return user_input + converted_user_input, validation_success = self._generic_types_validation_helper(expected_name, expected_input_type, user_input, source) - user_inputs, expected_input_types = self._generic_types_validation_helper(expected_name, expected_input_type, expected_input_type_str, user_input, source) - - # Validate that user type(s) match expected type(s) - validation_success = all( - type(inp) in expected_input_types for inp in user_inputs - ) - - # Validate that if lists are required, lists are given - if any(expected_input_type_str.startswith(list_hint) for list_hint in list_type_hints): - if type(user_input) is not list: - validation_success = False + # Validate parameter-specific requirements are met + try: + source.validate(converted_user_input) + except ValueError as e: + raise ValidationError(str(e), expected_name, expected_input_type) # Error if types don't match if not validation_success: - if hasattr( - original_expected_input_type, "__name__" - ) and not (original_expected_input_type_str.startswith("typing.") or original_expected_input_type_str.startswith("list")): - type_name = original_expected_input_type.__name__ - else: - type_name = original_expected_input_type_str + type_name = str(original_expected_input_type) raise ValidationError( f"must be type '{type_name}'", expected_name, original_expected_input_type, ) - # Validate parameter-specific requirements are met - try: - if type(user_input) is list: - source.validate(user_input) - else: - source.validate(user_inputs[0]) - except ValueError as e: - raise ValidationError(str(e), expected_name, expected_input_type) - - # Return input back to parent function - if any(expected_input_type_str.startswith(list_hint) for list_hint in list_type_hints): - return user_inputs - return user_inputs[0] + return converted_user_input diff --git a/flask_parameter_validation/test/test_form_params.py b/flask_parameter_validation/test/test_form_params.py index 810e611..4cf6a66 100644 --- a/flask_parameter_validation/test/test_form_params.py +++ b/flask_parameter_validation/test/test_form_params.py @@ -1,4 +1,5 @@ # String Validation +import sys import datetime import json import uuid @@ -1592,3 +1593,152 @@ def test_uuid_func(client): # Test that input failing func yields error r = client.post(url, data={"v": "492c6dfc-1730-11f0-9cd2-0242ac120002"}) assert "error" in r.json + + +def test_dict_args_str_str(client): + url = "/form/dict/args/str/str" + # Test that correct input yields input value + d = {"hi": "ho"} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + d = {"hi": -45} + # Test that incorrect input yields error + r = client.post(url, data={"v": json.dumps(d)}) + assert "error" in r.json + + +def test_dict_args_str_union(client): + url = "/form/dict/args/str/union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + + +def test_dict_args_str_list(client): + url = "/form/dict/args/str/list" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.post(url, data={"v": json.dumps(d)}) + assert "error" in r.json + + +def test_list_dict_args_str_union(client): + url = "/form/list/dict/args/str/union" + # Test that correct input yields input value + d = [{"id": 3, "chicken": "noodle soup"}, {}, {"foo": "bar"}] + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty list yields input value + d = [] + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = [{"id": 1.03, "name": "foo"}, {"id": -1}] + r = client.post(url, data={"v": json.dumps(d)}) + assert "error" in r.json + + + +if sys.version_info >= (3, 10): + def test_union_requred_3_10(client): + url = "/form/union/3_10/required" + # Test that missing input yields error + r = client.post(url) + assert "error" in r.json + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.post(url, data={"v": d}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.post(url, data={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = {"v": "string"} + # Test that present non-bool/datetime input yields error + r = client.post(url, data={"v": d}) + assert "error" in r.json + + def test_union_optional_3_10(client): + url = "/form/union/3_10/optional" + # Test that missing input yields input value + r = client.post(url) + assert "v" in r.json + assert r.json["v"] is None + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.post(url, data={"v": d}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.post(url, data={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = "string" + # Test that present non-bool/datetime input yields error + r = client.post(url, data={"v": d}) + assert "error" in r.json + + def test_dict_args_str_3_10_union(client): + url = "/form/dict/args/str/3_10_union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + + def test_dict_args_str_list_3_10_union(client): + url = "/form/dict/args/str/list/3_10_union" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, data={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.post(url, data={"v": json.dumps(d)}) + assert "error" in r.json + + diff --git a/flask_parameter_validation/test/test_json_params.py b/flask_parameter_validation/test/test_json_params.py index e585534..0131035 100644 --- a/flask_parameter_validation/test/test_json_params.py +++ b/flask_parameter_validation/test/test_json_params.py @@ -1,4 +1,5 @@ # String Validation +import sys import datetime import uuid from typing import Type, List, Optional @@ -1839,3 +1840,149 @@ def test_uuid_func(client): # Test that input failing func yields error r = client.post(url, json={"v": "492c6dfc-1730-11f0-9cd2-0242ac120002"}) assert "error" in r.json + + +def test_dict_args_str_str(client): + url = "/json/dict/args/str/str" + # Test that correct input yields input value + d = {"hi": "ho"} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = {"hi": -45} + # Test that incorrect input yields error + r = client.post(url, json={"v": d}) + assert "error" in r.json + + +def test_dict_args_str_union(client): + url = "/json/dict/args/str/union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + + +def test_dict_args_str_list(client): + url = "/json/dict/args/str/list" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.post(url, json={"v": d}) + assert "error" in r.json + +def test_list_dict_args_str_union(client): + url = "/json/list/dict/args/str/union" + # Test that correct input yields input value + d = [{"id": 3, "chicken": "noodle soup"}, {}, {"foo": "bar"}] + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty list yields input value + d = [] + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = [{"id": 1.03, "name": "foo"}, {"id": -1}] + r = client.post(url, json={"v": d}) + + +if sys.version_info >= (3, 10): + def test_union_requred_3_10(client): + url = "/json/union/3_10/required" + # Test that missing input yields error + r = client.post(url) + assert "error" in r.json + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.post(url, json={"v": d.isoformat()}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = {"v": "string"} + # Test that present non-bool/datetime input yields error + r = client.post(url, json={"v": d}) + assert "error" in r.json + + def test_union_optional_3_10(client): + url = "/json/union/3_10/optional" + # Test that missing input yields input value + r = client.post(url) + assert "v" in r.json + assert r.json["v"] is None + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.post(url, json={"v": d.isoformat()}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = "string" + # Test that present non-bool/datetime input yields error + r = client.post(url, json={"v": d}) + assert "error" in r.json + + def test_dict_args_str_3_10_union(client): + url = "/json/dict/args/str/3_10_union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + + def test_dict_args_str_list_3_10_union(client): + url = "/json/dict/args/str/list/3_10_union" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.post(url, json={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.post(url, json={"v": d}) + assert "error" in r.json + + diff --git a/flask_parameter_validation/test/test_multi_source_params.py b/flask_parameter_validation/test/test_multi_source_params.py index 440da46..bc4ff41 100644 --- a/flask_parameter_validation/test/test_multi_source_params.py +++ b/flask_parameter_validation/test/test_multi_source_params.py @@ -1,3 +1,4 @@ +import sys import datetime import json @@ -199,6 +200,72 @@ def test_multi_source_optional_dict(client, source_a, source_b): r = client.get(url) assert r.json["v"] is None +@pytest.mark.parametrize(*common_parameters) +def test_multi_source_dict_args_str_str(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + d = {"c": "d", "e": "f"} + url = f"/ms_{source_a}_{source_b}/dict/args/str/str" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(d)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(d)}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + +@pytest.mark.parametrize(*common_parameters) +def test_multi_source_dict_args_str_union(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + d = {"c": "d", "e": -3} + url = f"/ms_{source_a}_{source_b}/dict/args/str/union" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(d)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(d)}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + +@pytest.mark.parametrize(*common_parameters) +def test_multi_source_dict_args_str_list(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + d = {"c": True, "e": True, "b": [3, 4, 87]} + url = f"/ms_{source_a}_{source_b}/dict/args/str/list" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(d)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(d)}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + @pytest.mark.parametrize(*common_parameters) def test_multi_source_float(client, source_a, source_b): if source_a == source_b: # This shouldn't be something someone does, so we won't test for it @@ -343,6 +410,29 @@ def test_multi_source_optional_list(client, source_a, source_b): assert r.json["v"] is None +@pytest.mark.parametrize(*common_parameters) +def test_multi_source_list_dict(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'List' + return + l = [{"id": 3, "chicken": "noodle soup"}, {}, {"foo": "bar"}] + url = f"/ms_{source_a}_{source_b}/list/dict/args/str/union" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(l)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(l)}) + elif source == "json": + r = client.get(url, json={"v": l}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == l + + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + @pytest.mark.parametrize(*common_parameters) def test_multi_source_str(client, source_a, source_b): if source_a == source_b: # This shouldn't be something someone does, so we won't test for it @@ -599,4 +689,83 @@ def test_multi_source_optional_uuid(client, source_a, source_b): assert r.json["v"] == "28124cee-c074-448d-be63-6490ff5c89c0" # Test that missing input yields error r = client.get(url) - assert r.json["v"] is None \ No newline at end of file + assert r.json["v"] is None + +if sys.version_info >= (3, 10): + @pytest.mark.parametrize(*common_parameters) + def test_multi_source_3_10_union(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + url = f"/ms_{source_a}_{source_b}/union/3_10/required" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + d = False + if source == "query": + r = client.get(url, query_string={"v": d}) + elif source == "form": + r = client.get(url, data={"v": d}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + r = None + d = datetime.datetime.now().isoformat() + if source == "query": + r = client.get(url, query_string={"v": d}) + elif source == "form": + r = client.get(url, data={"v": d}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + + @pytest.mark.parametrize(*common_parameters) + def test_multi_source_dict_args_str_3_10_union(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + d = {"c": "d", "e": -3} + url = f"/ms_{source_a}_{source_b}/dict/args/str/3_10_union" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(d)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(d)}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + + @pytest.mark.parametrize(*common_parameters) + def test_multi_source_dict_args_str_list_3_10_union(client, source_a, source_b): + if source_a == source_b or "route" in [source_a, source_b]: # Duplicate sources shouldn't be something someone does, so we won't test for it, Route does not support parameters of type 'dict' + return + d = {"c": True, "e": True, "b": [3, 4, 87]} + url = f"/ms_{source_a}_{source_b}/dict/args/str/list/3_10_union" + for source in [source_a, source_b]: + # Test that present input yields input value + r = None + if source == "query": + r = client.get(url, query_string={"v": json.dumps(d)}) + elif source == "form": + r = client.get(url, data={"v": json.dumps(d)}) + elif source == "json": + r = client.get(url, json={"v": d}) + assert r is not None + assert "v" in r.json + assert r.json["v"] == d + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + diff --git a/flask_parameter_validation/test/test_query_params.py b/flask_parameter_validation/test/test_query_params.py index 9c73f3d..7c0b3eb 100644 --- a/flask_parameter_validation/test/test_query_params.py +++ b/flask_parameter_validation/test/test_query_params.py @@ -1,4 +1,5 @@ # String Validation +import sys import datetime import json import uuid @@ -2703,3 +2704,150 @@ def test_uuid_func(client): # Test that input failing func yields error r = client.get(url, query_string={"v": "492c6dfc-1730-11f0-9cd2-0242ac120002"}) assert "error" in r.json + + +def test_dict_args_str_str(client): + url = "/query/dict/args/str/str" + # Test that correct input yields input value + d = {"hi": "ho"} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + d = {"hi": -45} + # Test that incorrect input yields error + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "error" in r.json + + +def test_dict_args_str_union(client): + url = "/query/dict/args/str/union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + + +def test_dict_args_str_list(client): + url = "/query/dict/args/str/list" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "error" in r.json + +def test_list_dict_args_str_union(client): + url = "/query/list/dict/args/str/union" + # Test that correct input yields input value + d = [{"id": 3, "chicken": "noodle soup"}, {}, {"foo": "bar"}] + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty list yields input value + d = [] + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = [{"id": 1.03, "name": "foo"}, {"id": -1}] + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "error" in r.json + + +if sys.version_info >= (3, 10): + def test_union_requred_3_10(client): + url = "/query/union/3_10/required" + # Test that missing input yields error + r = client.get(url) + assert "error" in r.json + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.get(url, query_string={"v": d}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.get(url, query_string={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = {"v": "string"} + # Test that present non-bool/datetime input yields error + r = client.get(url, query_string={"v": d}) + assert "error" in r.json + + def test_union_optional_3_10(client): + url = "/query/union/3_10/optional" + # Test that missing input yields input value + r = client.get(url) + assert "v" in r.json + assert r.json["v"] is None + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.get(url, query_string={"v": d}) + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.get(url, query_string={"v": d}) + assert "v" in r.json + assert r.json["v"] == d + d = "string" + # Test that present non-bool/datetime input yields error + r = client.get(url, query_string={"v": d}) + assert "error" in r.json + + def test_dict_args_str_3_10_union(client): + url = "/query/dict/args/str/3_10_union" + # Test that union input yields input value + d = {"hi": "ho", "id": 1} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that only one type also yields input value + d = {"hi": 90, "id": 1} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + + def test_dict_args_str_list_3_10_union(client): + url = "/query/dict/args/str/list/3_10_union" + # Test that correct input yields input value + d = {"1.3": False, "9.0": [2, 4, 5]} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that empty dict yields input value + d = {} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "v" in r.json + assert r.json["v"] == d + # Test that incorrect values yields error + d = {"test": False, "ing": [2, True, 5]} + r = client.get(url, query_string={"v": json.dumps(d)}) + assert "error" in r.json + + diff --git a/flask_parameter_validation/test/test_route_params.py b/flask_parameter_validation/test/test_route_params.py index 8ef22a1..a99c261 100644 --- a/flask_parameter_validation/test/test_route_params.py +++ b/flask_parameter_validation/test/test_route_params.py @@ -1,4 +1,5 @@ # String Validation +import sys import datetime import uuid from typing import Type, List, Optional @@ -476,3 +477,22 @@ def test_uuid_func(client): # Test that input failing func yields error r = client.get(f"{url}/492c6dfc-1730-11f0-9cd2-0242ac120002") assert "error" in r.json + +if sys.version_info >= (3, 10): + def test_union_requred_3_10(client): + url = "/route/union/3_10/required" + # Test that present datetime input yields input value + d = datetime.datetime.now() + r = client.get(f"{url}/{d.isoformat()}") + assert "v" in r.json + assert r.json["v"] == d.isoformat() + # Test that present bool input yields input value + d = True + r = client.get(f"{url}/{d}") + assert "v" in r.json + assert r.json["v"] == d + d = {"v": "string"} + # Test that present non-bool/datetime input yields error + r = client.get(f"{url}/{d}") + assert "error" in r.json + diff --git a/flask_parameter_validation/test/testing_blueprints/dict_blueprint.py b/flask_parameter_validation/test/testing_blueprints/dict_blueprint.py index f914a91..481ca81 100644 --- a/flask_parameter_validation/test/testing_blueprints/dict_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/dict_blueprint.py @@ -1,4 +1,5 @@ import datetime +import sys from typing import Optional, List, Union from flask import Blueprint, jsonify @@ -93,4 +94,56 @@ def func(v: dict = ParamType(func=are_keys_lowercase)): def json_schema(v: dict = ParamType(json_schema=json_schema)): return jsonify({"v": v}) + @decorator("/args/str/str") + @ValidateParameters() + def args_str_str(v: dict[str, str] = ParamType(list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str + return jsonify({"v": v}) + + @decorator("/args/str/union") + @ValidateParameters() + def args_str_union(v: dict[str, Union[str,int]] = ParamType(list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + + @decorator("/args/str/list") + @ValidateParameters() + def args_str_list(v: dict[str, Union[list[int], bool]] = ParamType(list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is list or type(val) is bool + if type(val) is list: + for item in val: + assert type(item) is int + return jsonify({"v": v}) + + if sys.version_info >= (3, 10): + @decorator("/args/str/3_10_union") + @ValidateParameters() + def args_str_3_10_union(v: dict[str, str|int] = ParamType(list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + + @decorator("/args/str/list/3_10_union") + @ValidateParameters() + def args_str_list_3_10_union(v: dict[str, list[int] | bool] = ParamType(list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is list or type(val) is bool + if type(val) is list: + for item in val: + assert type(item) is int + return jsonify({"v": v}) + return dict_bp diff --git a/flask_parameter_validation/test/testing_blueprints/list_blueprint.py b/flask_parameter_validation/test/testing_blueprints/list_blueprint.py index 540be22..4367068 100644 --- a/flask_parameter_validation/test/testing_blueprints/list_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/list_blueprint.py @@ -373,4 +373,15 @@ def non_typing(v: list[str] = ParamType()): def optional_non_typing(v: Optional[list[str]] = ParamType()): return jsonify({"v": v}) + @decorator("/dict/args/str/union") + @ValidateParameters() + def dict_args_str_union(v: list[dict[str, Union[str, int]]] = ParamType(list_disable_query_csv=True)): + assert type(v) is list + for ele in v: + assert type(ele) is dict + for key, val in ele.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + return list_bp diff --git a/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py b/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py index a9f1b01..d57dfc4 100644 --- a/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/multi_source_blueprint.py @@ -1,3 +1,4 @@ +import sys import datetime import uuid from typing import Optional, List, Union @@ -89,7 +90,7 @@ def multi_source_int(v: int = MultiSource(sources[0], sources[1])): def multi_source_optional_int(v: Optional[int] = MultiSource(sources[0], sources[1])): return jsonify({"v": v}) - # Only List[int] is tested here - the other existing tests for lists should be exhaustive enough to catch issues + # Only List[int] and list[dict[str, Union[str, int]]] is tested here - the other existing tests for lists should be exhaustive enough to catch issues @param_bp.route("/required_list", methods=["GET", "POST"]) # Route doesn't support List parameters @ValidateParameters() @@ -99,6 +100,51 @@ def multi_source_list(v: List[int] = MultiSource(sources[0], sources[1])): assert type(v[0]) is int return jsonify({"v": v}) + @param_bp.route("/dict/args/str/str", methods=["GET", "POST"]) + # Route doesn't support List parameters + @ValidateParameters() + def multi_source_dict_str_str(v: dict[str, str] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str + return jsonify({"v": v}) + + @param_bp.route("/dict/args/str/union", methods=["GET", "POST"]) + # Route doesn't support List parameters + @ValidateParameters() + def multi_source_dict_str_union(v: dict[str, Union[str, int]] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + + @param_bp.route("/dict/args/str/list", methods=["GET", "POST"]) + # Route doesn't support List parameters + @ValidateParameters() + def multi_source_dict_str_list(v: dict[str, Union[list[int], bool]] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is list or type(val) is bool + if type(val) is list: + for ele in val: + assert type(ele) is int + return jsonify({"v": v}) + + @param_bp.route("/list/dict/args/str/union", methods=["GET", "POST"]) + # Route doesn't support List parameters + @ValidateParameters() + def multi_source_list_dict_str_union(v: list[dict[str, Union[str, int]]] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is list + for ele in v: + assert type(ele) is dict + for key, val in ele.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + @param_bp.route("/optional_list", methods=["GET", "POST"]) # Route doesn't support List parameters @ValidateParameters() @@ -162,4 +208,35 @@ def multi_source_uuid(v: uuid.UUID = MultiSource(sources[0], sources[1])): def multi_source_optional_uuid(v: Optional[uuid.UUID] = MultiSource(sources[0], sources[1])): return jsonify({"v": v}) + if sys.version_info >= (3, 10): + @param_bp.route("/union/3_10/required", methods=["GET", "POST"]) + @param_bp.route("/union/3_10/required/", methods=["GET", "POST"]) + @ValidateParameters() + def multi_source_3_10_union(v: bool | datetime.datetime = MultiSource(sources[0], sources[1])): + return jsonify({"v": v.isoformat() if type(v) is datetime.datetime else v}) + + @param_bp.route("/dict/args/str/3_10_union", methods=["GET", "POST"]) + # Route doesn't support Dict parameters + @ValidateParameters() + def multi_source_dict_str_3_10_union(v: dict[str, Union[str, int]] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is str or type(val) is int + return jsonify({"v": v}) + + @param_bp.route("/dict/args/str/list/3_10_union", methods=["GET", "POST"]) + # Route doesn't support Dict parameters + @ValidateParameters() + def multi_source_dict_str_list_3_10_union(v: dict[str, Union[list[int], bool]] = MultiSource(sources[0], sources[1], list_disable_query_csv=True)): + assert type(v) is dict + for key, val in v.items(): + assert type(key) is str + assert type(val) is list or type(val) is bool + if type(val) is list: + for ele in val: + assert type(ele) is int + return jsonify({"v": v}) + + return param_bp diff --git a/flask_parameter_validation/test/testing_blueprints/union_blueprint.py b/flask_parameter_validation/test/testing_blueprints/union_blueprint.py index 8f2bd3e..61cc9dc 100644 --- a/flask_parameter_validation/test/testing_blueprints/union_blueprint.py +++ b/flask_parameter_validation/test/testing_blueprints/union_blueprint.py @@ -1,4 +1,5 @@ import datetime +import sys from typing import Optional, Union from flask import Blueprint, jsonify @@ -62,4 +63,17 @@ def is_truthy(v): def func(v: Union[bool, int] = ParamType(func=is_truthy)): return jsonify({"v": v}) - return union_bp \ No newline at end of file + if sys.version_info >= (3, 10): + @decorator(path("/3_10/required", "/")) + @ValidateParameters() + def required_3_10(v: bool | datetime.datetime = ParamType()): + assert type(v) is bool or type(v) is datetime.datetime + return jsonify({"v": v.isoformat() if type(v) is datetime.datetime else v}) + + @decorator("/3_10/optional") # Route not supported by Optional + @ValidateParameters() + def optional_3_10(v: Optional[bool | datetime.datetime] = ParamType()): + assert type(v) is bool or type(v) is datetime.datetime or v is None + return jsonify({"v": v.isoformat() if type(v) is datetime.datetime else v}) + + return union_bp