Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support None operand in EQUAL operator #21713

Merged
Merged 2 commits on Oct 6, 2022.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,8 @@ class ChartDataFilterSchema(Schema):
)
val = fields.Raw(
description="The value or values to compare against. Can be a string, "
"integer, decimal or list, depending on the operator.",
"integer, decimal, None or list, depending on the operator.",
allow_none=True,
example=["China", "France", "Japan"],
)
grain = fields.String(
Expand Down
9 changes: 8 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,7 +1615,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif op == utils.FilterOperator.IS_FALSE.value:
where_clause_and.append(sqla_col.is_(False))
else:
if eq is None:
if (
op
not in {
utils.FilterOperator.EQUALS.value,
utils.FilterOperator.NOT_EQUALS.value,
}
and eq is None
):
raise QueryObjectValidationError(
_(
"Must specify a value for filters "
Expand Down
120 changes: 80 additions & 40 deletions tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from tests.integration_tests.test_app import app

from .base_tests import SupersetTestCase
from .conftest import only_postgresql

VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
"hive": re.compile(r"^INT_TYPE$"),
Expand Down Expand Up @@ -659,51 +660,90 @@ def test_filter_on_text_column(text_column_table):
assert result_object.df["count"][0] == 1


def test_should_generate_closed_and_open_time_filter_range():
with app.app_context():
if backend() != "postgresql":
pytest.skip(f"{backend()} has different dialect for datetime column")

table = SqlaTable(
table_name="temporal_column_table",
sql=(
"SELECT '2021-12-31'::timestamp as datetime_col "
"UNION SELECT '2022-01-01'::timestamp "
"UNION SELECT '2022-03-10'::timestamp "
"UNION SELECT '2023-01-01'::timestamp "
"UNION SELECT '2023-03-10'::timestamp "
),
database=get_example_database(),
)
TableColumn(
column_name="datetime_col",
type="TIMESTAMP",
table=table,
is_dttm=True,
)
SqlMetric(metric_name="count", expression="count(*)", table=table)
result_object = table.query(
@only_postgresql
def test_should_generate_closed_and_open_time_filter_range(login_as_admin):
    """Time-range filters must be closed on the lower bound and open on the upper."""
    # Virtual table with five timestamp rows spanning 2021-12-31 .. 2023-03-10.
    table = SqlaTable(
        table_name="temporal_column_table",
        sql=(
            "SELECT '2021-12-31'::timestamp as datetime_col "
            "UNION SELECT '2022-01-01'::timestamp "
            "UNION SELECT '2022-03-10'::timestamp "
            "UNION SELECT '2023-01-01'::timestamp "
            "UNION SELECT '2023-03-10'::timestamp "
        ),
        database=get_example_database(),
    )
    TableColumn(
        column_name="datetime_col",
        type="TIMESTAMP",
        table=table,
        is_dttm=True,
    )
    SqlMetric(metric_name="count", expression="count(*)", table=table)

    query_payload = {
        "metrics": ["count"],
        "is_timeseries": False,
        "filter": [],
        "from_dttm": datetime(2022, 1, 1),
        "to_dttm": datetime(2023, 1, 1),
        "granularity": "datetime_col",
    }
    query_result = table.query(query_payload)
    # The generated WHERE clause looks like:
    #   WHERE datetime_col >= TO_TIMESTAMP('2022-01-01 00:00:00.000000', ...)
    #     AND datetime_col <  TO_TIMESTAMP('2023-01-01 00:00:00.000000', ...)
    # i.e. the interval is [from_dttm, to_dttm), so only the 2022-01-01 and
    # 2022-03-10 rows match — the 2023-01-01 row is excluded by the open bound.
    assert query_result.df.iloc[0]["count"] == 2


def test_none_operand_in_filter(login_as_admin, physical_dataset):
    """A ``None`` operand is valid for EQUALS/NOT_EQUALS (rendered as
    ``IS NULL`` / ``IS NOT NULL``) and rejected with
    ``QueryObjectValidationError`` by every other comparison operator.
    """
    expected_results = [
        {
            "operator": FilterOperator.EQUALS.value,
            "count": 10,
            "sql_should_contain": "COL4 IS NULL",
        },
        {
            "operator": FilterOperator.NOT_EQUALS.value,
            "count": 0,
            "sql_should_contain": "COL4 IS NOT NULL",
        },
    ]
    for expected in expected_results:
        result = physical_dataset.query(
            {
                "metrics": ["count"],
                "filter": [{"col": "col4", "val": None, "op": expected["operator"]}],
                "is_timeseries": False,
            }
        )
        assert result.df["count"][0] == expected["count"]
        assert expected["sql_should_contain"] in result.query.upper()

    # NOTE(review): pytest.raises must wrap each query individually. The
    # original wrapped the whole loop in a single ``with pytest.raises(...)``
    # block, which exits as soon as the first operator raises — so only
    # GREATER_THAN was ever exercised and the remaining operators went untested.
    for flt in [
        FilterOperator.GREATER_THAN,
        FilterOperator.LESS_THAN,
        FilterOperator.GREATER_THAN_OR_EQUALS,
        FilterOperator.LESS_THAN_OR_EQUALS,
        FilterOperator.LIKE,
        FilterOperator.ILIKE,
    ]:
        with pytest.raises(QueryObjectValidationError):
            physical_dataset.query(
                {
                    "metrics": ["count"],
                    "filter": [{"col": "col4", "val": None, "op": flt.value}],
                    "is_timeseries": False,
                }
            )


@pytest.mark.parametrize(
Expand Down