Skip to content

Commit

Permalink
feat: generate label map on the backend (#21124)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyongjie committed Aug 22, 2022
1 parent 756ed0e commit 11bf7b9
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 2 deletions.
14 changes: 14 additions & 0 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import copy
import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union

import numpy as np
Expand Down Expand Up @@ -57,6 +58,7 @@
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.pandas_postprocessing.utils import unescape_separator
from superset.views.utils import get_viz

if TYPE_CHECKING:
Expand Down Expand Up @@ -142,6 +144,17 @@ def get_df_payload(
cache.error_message = str(ex)
cache.status = QueryStatus.FAILED

# the N-dimensional DataFrame has been converted into a flat DataFrame
# by the `flatten` operator; any "comma" in a column name was escaped by
# `escape_separator`, so the result DataFrame columns should be unescaped
label_map = {
unescape_separator(col): [
unescape_separator(col) for col in re.split(r"(?<!\\),\s", col)
]
for col in cache.df.columns.values
}
cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]

return {
"cache_key": cache_key,
"cached_dttm": cache.cache_dttm,
Expand All @@ -157,6 +170,7 @@ def get_df_payload(
"rowcount": len(cache.df.index),
"from_dttm": query_obj.from_dttm,
"to_dttm": query_obj.to_dttm,
"label_map": label_map,
}

def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
Expand Down
6 changes: 6 additions & 0 deletions superset/utils/pandas_postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
from superset.utils.pandas_postprocessing.rolling import rolling
from superset.utils.pandas_postprocessing.select import select
from superset.utils.pandas_postprocessing.sort import sort
from superset.utils.pandas_postprocessing.utils import (
escape_separator,
unescape_separator,
)

__all__ = [
"aggregate",
Expand All @@ -52,4 +56,6 @@
"select",
"sort",
"flatten",
"escape_separator",
"unescape_separator",
]
5 changes: 3 additions & 2 deletions superset/utils/pandas_postprocessing/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from superset.utils.pandas_postprocessing.utils import (
_is_multi_index_on_columns,
escape_separator,
FLAT_COLUMN_SEPARATOR,
)

Expand Down Expand Up @@ -86,8 +87,8 @@ def flatten(
_cells = []
for cell in series if is_sequence(series) else [series]:
if pd.notnull(cell):
# every cell should be converted to string
_cells.append(str(cell))
# every cell should be converted to string and escape comma
_cells.append(escape_separator(str(cell)))
_columns.append(FLAT_COLUMN_SEPARATOR.join(_cells))

df.columns = _columns
Expand Down
10 changes: 10 additions & 0 deletions superset/utils/pandas_postprocessing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,13 @@ def _append_columns(
return _base_df
append_df = append_df.rename(columns=columns)
return pd.concat([base_df, append_df], axis="columns")


def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
    """Prefix every occurrence of the separator character in *plain_str* with a backslash."""
    separator_char = sep.strip()
    return plain_str.replace(separator_char, f"\\{separator_char}")


def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
    """Reverse :func:`escape_separator` by dropping the backslash before each separator character."""
    separator_char = sep.strip()
    return escaped_str.replace(f"\\{separator_char}", separator_char)
27 changes: 27 additions & 0 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,30 @@ def physical_dataset():
for ds in dataset:
db.session.delete(ds)
db.session.commit()


@pytest.fixture
def virtual_dataset_comma_in_column_value():
    """Yield a virtual dataset whose column values contain commas.

    Used to exercise the escape/unescape handling of the flat-column
    separator; the table is removed from the session on teardown.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    table = SqlaTable(
        table_name="virtual_dataset",
        sql=(
            "SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
            "UNION ALL "
            "SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
            "UNION ALL "
            "SELECT 'col1,row3' as col1, 'col2, row3' as col2 "
        ),
        database=get_example_database(),
    )
    for column_name in ("col1", "col2"):
        TableColumn(column_name=column_name, type="VARCHAR(255)", table=table)

    SqlMetric(metric_name="count", expression="count(*)", table=table)
    db.session.merge(table)

    yield table

    db.session.delete(table)
    db.session.commit()
45 changes: 45 additions & 0 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.common.query_object import QueryObject
from superset.connectors.sqla.models import SqlMetric
from superset.datasource.dao import DatasourceDAO
Expand All @@ -35,6 +36,7 @@
DatasourceType,
QueryStatus,
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down Expand Up @@ -683,3 +685,46 @@ def test_time_offsets_accuracy(self):
row["sum__num__3 years later"]
== df_3_years_later.loc[index]["sum__num"]
)


def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):
    """Verify ``get_df_payload`` returns a ``label_map`` whose keys are the
    unescaped flat column names and whose values are the unescaped parts.

    The dataset's column values deliberately contain commas so that the
    escape/unescape round-trip through the ``flatten`` post-processing
    operation is exercised.
    """
    qc = QueryContextFactory().create(
        datasource={
            "type": virtual_dataset_comma_in_column_value.type,
            "id": virtual_dataset_comma_in_column_value.id,
        },
        queries=[
            {
                "columns": ["col1", "col2"],
                "metrics": ["count"],
                "post_processing": [
                    {
                        "operation": "pivot",
                        "options": {
                            "aggregates": {"count": {"operator": "mean"}},
                            "columns": ["col2"],
                            "index": ["col1"],
                        },
                    },
                    {"operation": "flatten"},
                ],
            }
        ],
        result_type=ChartDataResultType.FULL,
        force=True,
    )
    query_object = qc.queries[0]
    # fetch the payload once: `force=True` bypasses the cache, so calling
    # get_df_payload() twice would run the (identical) query twice
    payload = qc.get_df_payload(query_object)
    df = payload["df"]
    label_map = payload["label_map"]
    assert list(df.columns.values) == [
        "col1",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
    ]
    assert label_map == {
        "col1": ["col1"],
        "count, col2, row1": ["count", "col2, row1"],
        "count, col2, row2": ["count", "col2, row2"],
        "count, col2, row3": ["count", "col2, row3"],
    }
19 changes: 19 additions & 0 deletions tests/unit_tests/pandas_postprocessing/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,22 @@ def test_flat_integer_column_name():
}
)
)


def test_escape_column_name():
    """Flattening a MultiIndex-columned DataFrame escapes the separator
    character inside each column level before joining the levels."""
    ts_index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    ts_index.name = "__timestamp"
    multi_columns = pd.MultiIndex.from_arrays(
        [
            [f"level1,value{i}" for i in (1, 2, 3)],
            [f"level2, value{i}" for i in (1, 2, 3)],
        ],
        names=["level1", "level2"],
    )
    frame = pd.DataFrame(index=ts_index, columns=multi_columns, data=1)
    expected = ["__timestamp"] + [
        f"level1\\,value{i}" + FLAT_COLUMN_SEPARATOR + f"level2\\, value{i}"
        for i in (1, 2, 3)
    ]
    assert list(pp.flatten(frame).columns.values) == expected
30 changes: 30 additions & 0 deletions tests/unit_tests/pandas_postprocessing/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing import escape_separator, unescape_separator


def test_escape_separator():
    """escape_separator/unescape_separator leave separator-free strings alone
    and round-trip strings that contain the separator."""
    untouched = r" hell \world "
    assert escape_separator(untouched) == untouched
    assert unescape_separator(untouched) == untouched

    round_trips = [
        ("hello, world", r"hello\, world"),
        ("hello,world", r"hello\,world"),
    ]
    for plain, escaped in round_trips:
        assert escape_separator(plain) == escaped
        assert unescape_separator(escaped) == plain

0 comments on commit 11bf7b9

Please sign in to comment.