From 634226ae395b97e25b14d384891e0b761d9ff22d Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 13 Jul 2020 17:21:02 +0300 Subject: [PATCH] fix(chart-data-api): case insensitive evaluation of filter op (#10299) * fix(chart-data-api): case insensitive evaluation of filter op * fix(chart-data-api): case insensitive evaluation of filter op * mypy * remove print statement * add test --- superset/charts/schemas.py | 21 +++++---------- superset/utils/schema.py | 54 ++++++++++++++++++++++++++++++++++++++ tests/charts/api_tests.py | 14 ++++++++++ tests/utils_tests.py | 21 +++++++++++++++ 4 files changed, 96 insertions(+), 14 deletions(-) create mode 100644 superset/utils/schema.py diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 891dde727280..9e5e66334211 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Dict from flask_babel import gettext as _ -from marshmallow import fields, post_load, Schema, validate, ValidationError +from marshmallow import fields, post_load, Schema, validate from marshmallow.validate import Length, Range from superset.common.query_context import QueryContext -from superset.exceptions import SupersetException -from superset.utils import core as utils +from superset.utils import schema as utils +from superset.utils.core import FilterOperator # # RISON/JSON schemas for query parameters @@ -101,13 +101,6 @@ } -def validate_json(value: Union[bytes, bytearray, str]) -> None: - try: - utils.validate_json(value) - except SupersetException: - raise ValidationError("JSON not valid") - - class ChartPostSchema(Schema): """ Schema to add a new chart. @@ -124,7 +117,7 @@ class ChartPostSchema(Schema): ) owners = fields.List(fields.Integer(description=owners_description)) params = fields.String( - description=params_description, allow_none=True, validate=validate_json + description=params_description, allow_none=True, validate=utils.validate_json ) cache_timeout = fields.Integer( description=cache_timeout_description, allow_none=True @@ -573,8 +566,8 @@ class ChartDataFilterSchema(Schema): ) op = fields.String( # pylint: disable=invalid-name description="The comparison operator.", - validate=validate.OneOf( - choices=[filter_op.value for filter_op in utils.FilterOperator] + validate=utils.OneOfCaseInsensitive( + choices=[filter_op.value for filter_op in FilterOperator] ), required=True, example="IN", diff --git a/superset/utils/schema.py b/superset/utils/schema.py new file mode 100644 index 000000000000..2384e2851419 --- /dev/null +++ b/superset/utils/schema.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Union + +from marshmallow import validate, ValidationError + +from superset.exceptions import SupersetException +from superset.utils import core as utils + + +class OneOfCaseInsensitive(validate.OneOf): + """ + Marshmallow validator that's based on the built-in `OneOf`, but performs + validation case insensitively. + """ + + def __call__(self, value: Any) -> str: + try: + if (value.lower() if isinstance(value, str) else value) not in [ + choice.lower() if isinstance(choice, str) else choice + for choice in self.choices + ]: + raise ValidationError(self._format_error(value)) + except TypeError as error: + raise ValidationError(self._format_error(value)) from error + + return value + + +def validate_json(value: Union[bytes, bytearray, str]) -> None: + """ + JSON Validator that can be passed to a Marshmallow `Field`'s validate argument. + + :raises ValidationError: if value is not serializable to JSON + :param value: an object that should be parseable to JSON + """ + try: + utils.validate_json(value) + except SupersetException: + raise ValidationError("JSON not valid") diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index c3e3b50effb1..78461cf78a93 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -708,6 +708,20 @@ def test_chart_data_default_sample_limit(self): result = response_payload["result"][0] self.assertEqual(result["rowcount"], 5) + def test_chart_data_mixed_case_filter_op(self): + """ + Chart data API: Ensure mixed case filter operator generates valid result + """ + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + request_payload["queries"][0]["filters"][0]["op"] = "In" + request_payload["queries"][0]["row_limit"] = 10 + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + self.assertEqual(result["rowcount"], 10) + def test_chart_data_with_invalid_datasource(self): """Chart data API: Test chart data query with invalid schema """ diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 25ad02f33581..4c2092ae8501 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -27,6 +27,7 @@ import numpy from flask import Flask, g from flask_caching import Cache +import marshmallow from sqlalchemy.exc import ArgumentError import tests.test_app @@ -60,6 +61,7 @@ zlib_compress, zlib_decompress, ) +from superset.utils import schema from superset.views.utils import ( build_extra_filters, get_form_data, @@ -582,6 +584,8 @@ def test_json_encoded_obj(self): self.assertEqual(jsonObj.process_result_value(val, "dialect"), obj) def test_validate_json(self): + valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}' + self.assertIsNone(validate_json(valid)) invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}' with self.assertRaises(SupersetException): validate_json(invalid) @@ -1344,3 +1348,20 @@ def test_log_this(self) -> None: json.loads(record.json)["form_data"]["viz_type"], slc.viz.form_data["viz_type"], ) + + def test_schema_validate_json(self): + valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}' + self.assertIsNone(schema.validate_json(valid)) + invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}' + self.assertRaises(marshmallow.ValidationError, schema.validate_json, invalid) + + def test_schema_one_of_case_insensitive(self): + validator = schema.OneOfCaseInsensitive(choices=[1, 2, 3, "FoO", "BAR", "baz"]) + self.assertEqual(1, validator(1)) + self.assertEqual(2, validator(2)) + self.assertEqual("FoO", validator("FoO")) + self.assertEqual("FOO", validator("FOO")) + self.assertEqual("bar", validator("bar")) + self.assertEqual("BaZ", validator("BaZ")) + self.assertRaises(marshmallow.ValidationError, validator, "qwerty") + self.assertRaises(marshmallow.ValidationError, validator, 4)