Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: broken IS NULL and IS NOT NULL operator #9613

Merged
merged 3 commits into from Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/charts/schemas.py
Expand Up @@ -362,7 +362,7 @@ class ChartDataFilterSchema(Schema):
)
op = fields.String( # pylint: disable=invalid-name
description="The comparison operator.",
enum=[filter_op.value for filter_op in utils.FilterOperationType],
enum=[filter_op.value for filter_op in utils.FilterOperator],
required=True,
example="IN",
)
Expand Down
48 changes: 24 additions & 24 deletions superset/connectors/druid/models.py
Expand Up @@ -84,7 +84,7 @@
from superset.utils.core import (
DimSelector,
DTTM_ALIAS,
FilterOperationType,
FilterOperator,
flasher,
)
except ImportError:
Expand Down Expand Up @@ -1499,8 +1499,8 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
eq is None
and op
not in (
FilterOperationType.IS_NULL.value,
FilterOperationType.IS_NOT_NULL.value,
FilterOperator.IS_NULL.value,
FilterOperator.IS_NOT_NULL.value,
)
)
):
Expand All @@ -1517,8 +1517,8 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
cond = None
is_numeric_col = col in num_cols
is_list_target = op in (
FilterOperationType.IN.value,
FilterOperationType.NOT_IN.value,
FilterOperator.IN.value,
FilterOperator.NOT_IN.value,
)
eq = cls.filter_values_handler(
eq,
Expand All @@ -1528,11 +1528,11 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":

# For these two ops, could have used Dimension,
# but it doesn't support extraction functions
if op == FilterOperationType.EQUALS.value:
if op == FilterOperator.EQUALS.value:
cond = Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
elif op == FilterOperationType.NOT_EQUALS.value:
elif op == FilterOperator.NOT_EQUALS.value:
cond = ~Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
Expand All @@ -1557,9 +1557,9 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
for s in eq:
fields.append(Dimension(col) == s)
cond = Filter(type="or", fields=fields)
if op == FilterOperationType.NOT_IN.value:
if op == FilterOperator.NOT_IN.value:
cond = ~cond
elif op == FilterOperationType.REGEX.value:
elif op == FilterOperator.REGEX.value:
cond = Filter(
extraction_function=extraction_fn,
type="regex",
Expand All @@ -1569,7 +1569,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":

# For the ops below, could have used pydruid's Bound,
# but it doesn't support extraction functions
elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
elif op == FilterOperator.GREATER_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
Expand All @@ -1579,7 +1579,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
elif op == FilterOperator.LESS_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
Expand All @@ -1589,7 +1589,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == FilterOperationType.GREATER_THAN.value:
elif op == FilterOperator.GREATER_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
lowerStrict=True,
Expand All @@ -1599,7 +1599,7 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == FilterOperationType.LESS_THAN.value:
elif op == FilterOperator.LESS_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
upperStrict=True,
Expand All @@ -1609,9 +1609,9 @@ def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter":
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == FilterOperationType.IS_NULL.value:
elif op == FilterOperator.IS_NULL.value:
cond = Filter(dimension=col, value="")
elif op == FilterOperationType.IS_NOT_NULL.value:
elif op == FilterOperator.IS_NOT_NULL.value:
cond = ~Filter(dimension=col, value="")

if filters:
Expand All @@ -1627,24 +1627,24 @@ def _get_ordering(is_numeric_col: bool) -> str:

def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
cond = None
if op == FilterOperationType.EQUALS.value:
if op == FilterOperator.EQUALS.value:
if col in self.column_names:
cond = DimSelector(dimension=col, value=eq)
else:
cond = Aggregation(col) == eq
elif op == FilterOperationType.GREATER_THAN.value:
elif op == FilterOperator.GREATER_THAN.value:
cond = Aggregation(col) > eq
elif op == FilterOperationType.LESS_THAN.value:
elif op == FilterOperator.LESS_THAN.value:
cond = Aggregation(col) < eq

return cond

def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
filters = None
reversed_op_map = {
FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value,
FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value,
FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value,
FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value,
FilterOperator.GREATER_THAN_OR_EQUALS.value: FilterOperator.LESS_THAN.value,
FilterOperator.LESS_THAN_OR_EQUALS.value: FilterOperator.GREATER_THAN.value,
}

for flt in raw_filters:
Expand All @@ -1655,9 +1655,9 @@ def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having":
eq = flt["val"]
cond = None
if op in [
FilterOperationType.EQUALS.value,
FilterOperationType.GREATER_THAN.value,
FilterOperationType.LESS_THAN.value,
FilterOperator.EQUALS.value,
FilterOperator.GREATER_THAN.value,
FilterOperator.LESS_THAN.value,
]:
cond = self._get_having_obj(col, op, eq)
elif op in reversed_op_map:
Expand Down
32 changes: 16 additions & 16 deletions superset/connectors/sqla/models.py
Expand Up @@ -847,45 +847,45 @@ def get_sqla_query( # sqla
col_obj = cols.get(col)
if col_obj:
is_list_target = op in (
utils.FilterOperationType.IN.value,
utils.FilterOperationType.NOT_IN.value,
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
)
eq = self.filter_values_handler(
values=flt.get("val"),
target_column_is_numeric=col_obj.is_numeric,
is_list_target=is_list_target,
)
if op in (
utils.FilterOperationType.IN.value,
utils.FilterOperationType.NOT_IN.value,
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
):
cond = col_obj.get_sqla_col().in_(eq)
if isinstance(eq, str) and NULL_STRING in eq:
cond = or_(cond, col_obj.get_sqla_col() is None)
if op == utils.FilterOperationType.NOT_IN.value:
if op == utils.FilterOperator.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_numeric:
eq = utils.cast_to_num(flt["val"])
if op == utils.FilterOperationType.EQUALS.value:
if op == utils.FilterOperator.EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == utils.FilterOperationType.NOT_EQUALS.value:
elif op == utils.FilterOperator.NOT_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == utils.FilterOperationType.GREATER_THAN.value:
elif op == utils.FilterOperator.GREATER_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == utils.FilterOperationType.LESS_THAN.value:
elif op == utils.FilterOperator.LESS_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == utils.FilterOperationType.LIKE.value:
elif op == utils.FilterOperator.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == utils.FilterOperationType.IS_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() is None)
elif op == utils.FilterOperationType.IS_NOT_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() is None)
elif op == utils.FilterOperator.IS_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() == None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is None and is not None is preferred

Copy link
Member Author

@villebro villebro Apr 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually doesn't work in this case (tested), SqlAlchemy expects them to be in this form.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going to add unit tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yes I understand!

elif op == utils.FilterOperator.IS_NOT_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() != None)
else:
raise Exception(
_("Invalid filter operation type: %(op)s", op=op)
Expand Down
4 changes: 2 additions & 2 deletions superset/utils/core.py
Expand Up @@ -1350,9 +1350,9 @@ class DbColumnType(Enum):
TEMPORAL = 2


class FilterOperationType(str, Enum):
class FilterOperator(str, Enum):
"""
Filter operation type
Operators used filter controls
"""

EQUALS = "=="
Expand Down
40 changes: 38 additions & 2 deletions tests/sqla_models_tests.py
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from typing import Dict
from typing import Any, Dict, NamedTuple, List, Tuple, Union

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.models.core import Database
from superset.utils.core import DbColumnType, get_example_database
from superset.utils.core import DbColumnType, get_example_database, FilterOperator

from .base_tests import SupersetTestCase

Expand Down Expand Up @@ -109,3 +109,39 @@ def test_has_no_extra_cache_keys(self):
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj))
self.assertListEqual(extra_cache_keys, [])

def test_where_operators(self):
class FilterTestCase(NamedTuple):
operator: str
value: Union[float, int, List[Any], str]
expected: str

filters: Tuple[FilterTestCase, ...] = (
FilterTestCase(FilterOperator.IS_NULL, "", "IS NULL"),
FilterTestCase(FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"),
FilterTestCase(FilterOperator.GREATER_THAN, 0, "> 0"),
FilterTestCase(FilterOperator.GREATER_THAN_OR_EQUALS, 0, ">= 0"),
FilterTestCase(FilterOperator.LESS_THAN, 0, "< 0"),
FilterTestCase(FilterOperator.LESS_THAN_OR_EQUALS, 0, "<= 0"),
FilterTestCase(FilterOperator.EQUALS, 0, "= 0"),
FilterTestCase(FilterOperator.NOT_EQUALS, 0, "!= 0"),
FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"),
FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"),
)
table = self.get_table_by_name("birth_names")
for filter_ in filters:
query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["gender"],
"metrics": ["count"],
"is_timeseries": False,
"filter": [
{"col": "num", "op": filter_.operator, "val": filter_.value}
],
"extras": {},
}
sqla_query = table.get_sqla_query(**query_obj)
sql = table.database.compile_sqla_query(sqla_query.sqla_query)
self.assertIn(filter_.expected, sql)