From dfa7a1e351ee09327f210e312e466344496757d6 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Apr 2020 16:39:05 +0300 Subject: [PATCH 1/3] fix: broken is null and is not null operator --- superset/connectors/sqla/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8aac7e091e78..0f237d07b2e1 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -883,9 +883,9 @@ def get_sqla_query( # sqla elif op == utils.FilterOperationType.LIKE.value: where_clause_and.append(col_obj.get_sqla_col().like(eq)) elif op == utils.FilterOperationType.IS_NULL.value: - where_clause_and.append(col_obj.get_sqla_col() is None) + where_clause_and.append(col_obj.get_sqla_col() == None) elif op == utils.FilterOperationType.IS_NOT_NULL.value: - where_clause_and.append(col_obj.get_sqla_col() is None) + where_clause_and.append(col_obj.get_sqla_col() != None) else: raise Exception( _("Invalid filter operation type: %(op)s", op=op) From dc1279fdd743328e0281730e6955811119d66cd4 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Apr 2020 17:35:09 +0300 Subject: [PATCH 2/3] add unit tests --- tests/sqla_models_tests.py | 40 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 700c2e2bb94a..22f4c3962e3c 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -from typing import Dict +from typing import Any, Dict, NamedTuple, List, Tuple, Union from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.db_engine_specs.druid import DruidEngineSpec from superset.models.core import Database -from superset.utils.core import DbColumnType, get_example_database +from superset.utils.core import DbColumnType, get_example_database, FilterOperationType from .base_tests import SupersetTestCase @@ -109,3 +109,39 @@ def test_has_no_extra_cache_keys(self): extra_cache_keys = table.get_extra_cache_keys(query_obj) self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj)) self.assertListEqual(extra_cache_keys, []) + + def test_where_operators(self): + class FilterTestCase(NamedTuple): + operator: str + value: Union[float, int, List[Any], str] + expected: str + + filters: Tuple[FilterTestCase, ...] = ( + FilterTestCase(FilterOperationType.IS_NULL, "", "IS NULL"), + FilterTestCase(FilterOperationType.IS_NOT_NULL, "", "IS NOT NULL"), + FilterTestCase(FilterOperationType.GREATER_THAN, 0, "> 0"), + FilterTestCase(FilterOperationType.GREATER_THAN_OR_EQUALS, 0, ">= 0"), + FilterTestCase(FilterOperationType.LESS_THAN, 0, "< 0"), + FilterTestCase(FilterOperationType.LESS_THAN_OR_EQUALS, 0, "<= 0"), + FilterTestCase(FilterOperationType.EQUALS, 0, "= 0"), + FilterTestCase(FilterOperationType.NOT_EQUALS, 0, "!= 0"), + FilterTestCase(FilterOperationType.IN, ["1", "2"], "IN (1, 2)"), + FilterTestCase(FilterOperationType.NOT_IN, ["1", "2"], "NOT IN (1, 2)"), + ) + table = self.get_table_by_name("birth_names") + for filter_ in filters: + query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["gender"], + "metrics": ["count"], + "is_timeseries": False, + "filter": [ + {"col": "num", "op": filter_.operator, "val": filter_.value} + ], + "extras": {}, + } + sqla_query = table.get_sqla_query(**query_obj) + sql = table.database.compile_sqla_query(sqla_query.sqla_query) + self.assertIn(filter_.expected, sql) From fe4838d10d4b5114689432f6e4846dc8143df5ed Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Apr 2020 17:41:35 +0300 Subject: [PATCH 3/3] Rename filter operator enum --- superset/charts/schemas.py | 2 +- superset/connectors/druid/models.py | 48 ++++++++++++++--------------- superset/connectors/sqla/models.py | 28 ++++++++--------- superset/utils/core.py | 4 +-- tests/sqla_models_tests.py | 22 ++++++------- 5 files changed, 52 insertions(+), 52 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 49d0480224e1..2743732e6e5f 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -362,7 +362,7 @@ class ChartDataFilterSchema(Schema): ) op = fields.String( # pylint: disable=invalid-name description="The comparison operator.", - enum=[filter_op.value for filter_op in utils.FilterOperationType], + enum=[filter_op.value for filter_op in utils.FilterOperator], required=True, example="IN", ) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 8d4aeb1ef0e5..eef20e215c18 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -84,7 +84,7 @@ from superset.utils.core import ( DimSelector, DTTM_ALIAS, - FilterOperationType, + FilterOperator, flasher, ) except ImportError: @@ -1499,8 +1499,8 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": eq is None and op not in ( - FilterOperationType.IS_NULL.value, - FilterOperationType.IS_NOT_NULL.value, + FilterOperator.IS_NULL.value, + FilterOperator.IS_NOT_NULL.value, ) ) ): @@ -1517,8 +1517,8 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": cond = None is_numeric_col = col in num_cols is_list_target = op in ( - FilterOperationType.IN.value, - FilterOperationType.NOT_IN.value, + FilterOperator.IN.value, + FilterOperator.NOT_IN.value, ) eq = cls.filter_values_handler( eq, @@ -1528,11 +1528,11 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": # For these two ops, could have used Dimension, # but it doesn't support extraction functions - if op == FilterOperationType.EQUALS.value: + if op == FilterOperator.EQUALS.value: cond = Filter( dimension=col, value=eq, extraction_function=extraction_fn ) - elif op == FilterOperationType.NOT_EQUALS.value: + elif op == FilterOperator.NOT_EQUALS.value: cond = ~Filter( dimension=col, value=eq, extraction_function=extraction_fn ) @@ -1557,9 +1557,9 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": for s in eq: fields.append(Dimension(col) == s) cond = Filter(type="or", fields=fields) - if op == FilterOperationType.NOT_IN.value: + if op == FilterOperator.NOT_IN.value: cond = ~cond - elif op == FilterOperationType.REGEX.value: + elif op == FilterOperator.REGEX.value: cond = Filter( extraction_function=extraction_fn, type="regex", @@ -1569,7 +1569,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": # For the ops below, could have used pydruid's Bound, # but it doesn't support extraction functions - elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value: + elif op == FilterOperator.GREATER_THAN_OR_EQUALS.value: cond = Bound( extraction_function=extraction_fn, dimension=col, @@ -1579,7 +1579,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": upper=None, ordering=cls._get_ordering(is_numeric_col), ) - elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value: + elif op == FilterOperator.LESS_THAN_OR_EQUALS.value: cond = Bound( extraction_function=extraction_fn, dimension=col, @@ -1589,7 +1589,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": upper=eq, ordering=cls._get_ordering(is_numeric_col), ) - elif op == FilterOperationType.GREATER_THAN.value: + elif op == FilterOperator.GREATER_THAN.value: cond = Bound( extraction_function=extraction_fn, lowerStrict=True, @@ -1599,7 +1599,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": upper=None, ordering=cls._get_ordering(is_numeric_col), ) - elif op == FilterOperationType.LESS_THAN.value: + elif op == FilterOperator.LESS_THAN.value: cond = Bound( extraction_function=extraction_fn, upperStrict=True, @@ -1609,9 +1609,9 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": upper=eq, ordering=cls._get_ordering(is_numeric_col), ) - elif op == FilterOperationType.IS_NULL.value: + elif op == FilterOperator.IS_NULL.value: cond = Filter(dimension=col, value="") - elif op == FilterOperationType.IS_NOT_NULL.value: + elif op == FilterOperator.IS_NOT_NULL.value: cond = ~Filter(dimension=col, value="") if filters: @@ -1627,14 +1627,14 @@ def _get_ordering(is_numeric_col: bool) -> str: def _get_having_obj(self, col: str, op: str, eq: str) -> "Having": cond = None - if op == FilterOperationType.EQUALS.value: + if op == FilterOperator.EQUALS.value: if col in self.column_names: cond = DimSelector(dimension=col, value=eq) else: cond = Aggregation(col) == eq - elif op == FilterOperationType.GREATER_THAN.value: + elif op == FilterOperator.GREATER_THAN.value: cond = Aggregation(col) > eq - elif op == FilterOperationType.LESS_THAN.value: + elif op == FilterOperator.LESS_THAN.value: cond = Aggregation(col) < eq return cond @@ -1642,9 +1642,9 @@ def _get_having_obj(self, col: str, op: str, eq: str) -> "Having": def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having": filters = None reversed_op_map = { - FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value, - FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value, - FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value, + FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value, + FilterOperator.GREATER_THAN_OR_EQUALS.value: FilterOperator.LESS_THAN.value, + FilterOperator.LESS_THAN_OR_EQUALS.value: FilterOperator.GREATER_THAN.value, } for flt in raw_filters: @@ -1655,9 +1655,9 @@ def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having": eq = flt["val"] cond = None if op in [ - FilterOperationType.EQUALS.value, - FilterOperationType.GREATER_THAN.value, - FilterOperationType.LESS_THAN.value, + FilterOperator.EQUALS.value, + FilterOperator.GREATER_THAN.value, + FilterOperator.LESS_THAN.value, ]: cond = self._get_having_obj(col, op, eq) elif op in reversed_op_map: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 0f237d07b2e1..1eed4be2ab86 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -847,8 +847,8 @@ def get_sqla_query( # sqla col_obj = cols.get(col) if col_obj: is_list_target = op in ( - utils.FilterOperationType.IN.value, - utils.FilterOperationType.NOT_IN.value, + utils.FilterOperator.IN.value, + utils.FilterOperator.NOT_IN.value, ) eq = self.filter_values_handler( values=flt.get("val"), @@ -856,35 +856,35 @@ def get_sqla_query( # sqla is_list_target=is_list_target, ) if op in ( - utils.FilterOperationType.IN.value, - utils.FilterOperationType.NOT_IN.value, + utils.FilterOperator.IN.value, + utils.FilterOperator.NOT_IN.value, ): cond = col_obj.get_sqla_col().in_(eq) if isinstance(eq, str) and NULL_STRING in eq: cond = or_(cond, col_obj.get_sqla_col() is None) - if op == utils.FilterOperationType.NOT_IN.value: + if op == utils.FilterOperator.NOT_IN.value: cond = ~cond where_clause_and.append(cond) else: if col_obj.is_numeric: eq = utils.cast_to_num(flt["val"]) - if op == utils.FilterOperationType.EQUALS.value: + if op == utils.FilterOperator.EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() == eq) - elif op == utils.FilterOperationType.NOT_EQUALS.value: + elif op == utils.FilterOperator.NOT_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() != eq) - elif op == utils.FilterOperationType.GREATER_THAN.value: + elif op == utils.FilterOperator.GREATER_THAN.value: where_clause_and.append(col_obj.get_sqla_col() > eq) - elif op == utils.FilterOperationType.LESS_THAN.value: + elif op == utils.FilterOperator.LESS_THAN.value: where_clause_and.append(col_obj.get_sqla_col() < eq) - elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value: + elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() >= eq) - elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value: + elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() <= eq) - elif op == utils.FilterOperationType.LIKE.value: + elif op == utils.FilterOperator.LIKE.value: where_clause_and.append(col_obj.get_sqla_col().like(eq)) - elif op == utils.FilterOperationType.IS_NULL.value: + elif op == utils.FilterOperator.IS_NULL.value: where_clause_and.append(col_obj.get_sqla_col() == None) - elif op == utils.FilterOperationType.IS_NOT_NULL.value: + elif op == utils.FilterOperator.IS_NOT_NULL.value: where_clause_and.append(col_obj.get_sqla_col() != None) else: raise Exception( diff --git a/superset/utils/core.py b/superset/utils/core.py index 5749930aee34..3745be8a964e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1350,9 +1350,9 @@ class DbColumnType(Enum): TEMPORAL = 2 -class FilterOperationType(str, Enum): +class FilterOperator(str, Enum): """ - Filter operation type + Operators used filter controls """ EQUALS = "==" diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 22f4c3962e3c..af15b266b65d 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -20,7 +20,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.db_engine_specs.druid import DruidEngineSpec from superset.models.core import Database -from superset.utils.core import DbColumnType, get_example_database, FilterOperationType +from superset.utils.core import DbColumnType, get_example_database, FilterOperator from .base_tests import SupersetTestCase @@ -117,16 +117,16 @@ class FilterTestCase(NamedTuple): expected: str filters: Tuple[FilterTestCase, ...] = ( - FilterTestCase(FilterOperationType.IS_NULL, "", "IS NULL"), - FilterTestCase(FilterOperationType.IS_NOT_NULL, "", "IS NOT NULL"), - FilterTestCase(FilterOperationType.GREATER_THAN, 0, "> 0"), - FilterTestCase(FilterOperationType.GREATER_THAN_OR_EQUALS, 0, ">= 0"), - FilterTestCase(FilterOperationType.LESS_THAN, 0, "< 0"), - FilterTestCase(FilterOperationType.LESS_THAN_OR_EQUALS, 0, "<= 0"), - FilterTestCase(FilterOperationType.EQUALS, 0, "= 0"), - FilterTestCase(FilterOperationType.NOT_EQUALS, 0, "!= 0"), - FilterTestCase(FilterOperationType.IN, ["1", "2"], "IN (1, 2)"), - FilterTestCase(FilterOperationType.NOT_IN, ["1", "2"], "NOT IN (1, 2)"), + FilterTestCase(FilterOperator.IS_NULL, "", "IS NULL"), + FilterTestCase(FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"), + FilterTestCase(FilterOperator.GREATER_THAN, 0, "> 0"), + FilterTestCase(FilterOperator.GREATER_THAN_OR_EQUALS, 0, ">= 0"), + FilterTestCase(FilterOperator.LESS_THAN, 0, "< 0"), + FilterTestCase(FilterOperator.LESS_THAN_OR_EQUALS, 0, "<= 0"), + FilterTestCase(FilterOperator.EQUALS, 0, "= 0"), + FilterTestCase(FilterOperator.NOT_EQUALS, 0, "!= 0"), + FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"), + FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"), ) table = self.get_table_by_name("birth_names") for filter_ in filters: