Skip to content

Commit

Permalink
fix: Always use temporal type for dttm columns [ID-2] (#17458)
Browse files Browse the repository at this point in the history
* fix: Always use temporal type for dttm columns

* move inference and implement in chart postproc

* fix test

* fix test case

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
  • Loading branch information
kgabryje and villebro committed Nov 22, 2021
1 parent 66d7569 commit 1f8eff7
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 11 deletions.
15 changes: 11 additions & 4 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from superset.charts.post_processing import apply_post_process
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
Expand Down Expand Up @@ -158,7 +159,9 @@ def get_data(self, pk: int) -> Response:
except (TypeError, json.decoder.JSONDecodeError):
form_data = {}

return self._get_data_response(command, form_data=form_data)
return self._get_data_response(
command=command, form_data=form_data, datasource=query_context.datasource
)

@expose("/data", methods=["POST"])
@protect()
Expand Down Expand Up @@ -327,7 +330,10 @@ def _run_async(
return self.response(202, **result)

def _send_chart_response(
self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
self,
result: Dict[Any, Any],
form_data: Optional[Dict[str, Any]] = None,
datasource: Optional[BaseDatasource] = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format
Expand All @@ -336,7 +342,7 @@ def _send_chart_response(
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)
result = apply_post_process(result, form_data, datasource)

if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file
Expand Down Expand Up @@ -364,6 +370,7 @@ def _get_data_response(
command: ChartDataCommand,
force_cached: bool = False,
form_data: Optional[Dict[str, Any]] = None,
datasource: Optional[BaseDatasource] = None,
) -> Response:
try:
result = command.run(force_cached=force_cached)
Expand All @@ -372,7 +379,7 @@ def _get_data_response(
except ChartDataQueryFailedError as exc:
return self.response_400(message=exc.message)

return self._send_chart_response(result, form_data)
return self._send_chart_response(result, form_data, datasource)

# pylint: disable=invalid-name, no-self-use
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
Expand Down
11 changes: 8 additions & 3 deletions superset/charts/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
"""

from io import StringIO
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import pandas as pd

from superset.common.chart_data import ChartDataResultFormat
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource


def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
"""
Expand Down Expand Up @@ -284,7 +287,9 @@ def table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:


def apply_post_process(
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
result: Dict[Any, Any],
form_data: Optional[Dict[str, Any]] = None,
datasource: Optional["BaseDatasource"] = None,
) -> Dict[Any, Any]:
form_data = form_data or {}

Expand All @@ -306,7 +311,7 @@ def apply_post_process(

query["colnames"] = list(processed_df.columns)
query["indexnames"] = list(processed_df.index)
query["coltypes"] = extract_dataframe_dtypes(processed_df)
query["coltypes"] = extract_dataframe_dtypes(processed_df, datasource)
query["rowcount"] = len(processed_df.index)

# Flatten hierarchical columns/index since they are represented as
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _get_full(
if status != QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["indexnames"] = list(df.index)
payload["coltypes"] = extract_dataframe_dtypes(df)
payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
payload["data"] = query_context.get_data(df)
payload["result_format"] = query_context.result_format
del payload["df"]
Expand Down
16 changes: 14 additions & 2 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,9 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
return [col for col in map(get_column_name_from_metric, metrics) if col]


def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
def extract_dataframe_dtypes(
df: pd.DataFrame, datasource: Optional["BaseDatasource"] = None,
) -> List[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""

# omitting string types as those will be the default type
Expand All @@ -1612,11 +1614,21 @@ def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
"date": GenericDataType.TEMPORAL,
}

columns_by_name = (
{column.column_name: column for column in datasource.columns}
if datasource
else {}
)
generic_types: List[GenericDataType] = []
for column in df.columns:
column_object = columns_by_name.get(column)
series = df[column]
inferred_type = infer_dtype(series)
generic_type = inferred_type_map.get(inferred_type, GenericDataType.STRING)
generic_type = (
GenericDataType.TEMPORAL
if column_object and column_object.is_dttm
else inferred_type_map.get(inferred_type, GenericDataType.STRING)
)
generic_types.append(generic_type)

return generic_types
Expand Down
7 changes: 6 additions & 1 deletion tests/integration_tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,9 @@ def test_get_form_data_token(self):
generated_token = get_form_data_token({})
assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_extract_dataframe_dtypes(self):
slc = self.get_slice("Girls", db.session)
cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
(
Expand All @@ -1147,10 +1149,13 @@ def test_extract_dataframe_dtypes(self):
("float_null", GenericDataType.NUMERIC, [None, 0.5]),
("bool_null", GenericDataType.BOOLEAN, [None, False]),
("obj_null", GenericDataType.STRING, [None, {"a": 1}]),
# Non-timestamp columns should be identified as temporal if
# `is_dttm` is set to `True` in the underlying datasource
("ds", GenericDataType.TEMPORAL, [None, {"ds": "2017-01-01"}]),
)

df = pd.DataFrame(data={col[0]: col[2] for col in cols})
assert extract_dataframe_dtypes(df) == [col[1] for col in cols]
assert extract_dataframe_dtypes(df, slc.datasource) == [col[1] for col in cols]

def test_normalize_dttm_col(self):
def normalize_col(
Expand Down

0 comments on commit 1f8eff7

Please sign in to comment.