fix: Always use temporal type for dttm columns [ID-2] #17458

Merged 4 commits on Nov 22, 2021
Changes from all commits
15 changes: 11 additions & 4 deletions superset/charts/data/api.py
@@ -39,6 +39,7 @@
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.post_processing import apply_post_process
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
@@ -155,7 +156,9 @@ def get_data(self, pk: int) -> Response:
        except (TypeError, json.decoder.JSONDecodeError):
            form_data = {}

        return self._get_data_response(command, form_data=form_data)
        return self._get_data_response(
            command=command, form_data=form_data, datasource=query_context.datasource
        )

    @expose("/data", methods=["POST"])
    @protect()
@@ -324,7 +327,10 @@ def _run_async(
        return self.response(202, **result)

    def _send_chart_response(
        self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
        self,
        result: Dict[Any, Any],
        form_data: Optional[Dict[str, Any]] = None,
        datasource: Optional[BaseDatasource] = None,
    ) -> Response:
        result_type = result["query_context"].result_type
        result_format = result["query_context"].result_format
@@ -333,7 +339,7 @@ def _send_chart_response(
        # This is needed for sending reports based on text charts that do the
        # post-processing of data, eg, the pivot table.
        if result_type == ChartDataResultType.POST_PROCESSED:
            result = apply_post_process(result, form_data)
            result = apply_post_process(result, form_data, datasource)

        if result_format == ChartDataResultFormat.CSV:
            # Verify user has permission to export CSV file
@@ -361,6 +367,7 @@ def _get_data_response(
        command: ChartDataCommand,
        force_cached: bool = False,
        form_data: Optional[Dict[str, Any]] = None,
        datasource: Optional[BaseDatasource] = None,
    ) -> Response:
        try:
            result = command.run(force_cached=force_cached)
@@ -369,7 +376,7 @@
        except ChartDataQueryFailedError as exc:
            return self.response_400(message=exc.message)

        return self._send_chart_response(result, form_data)
        return self._send_chart_response(result, form_data, datasource)

    # pylint: disable=invalid-name, no-self-use
    def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
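
For context, here is a minimal, self-contained sketch (not Superset code) of the pattern the api.py changes above follow: `datasource` is added as an optional keyword argument at each hop, so existing callers that never pass it keep their old behaviour, while the chart-data path can forward `query_context.datasource` down to the type-extraction step. The function bodies are simplified stand-ins.

```python
from typing import Any, Dict, Optional


def _send_chart_response(
    result: Dict[Any, Any],
    form_data: Optional[Dict[str, Any]] = None,
    datasource: Optional[object] = None,  # BaseDatasource in the real code
) -> Dict[str, Any]:
    # The datasource is only consulted when a caller actually supplies one.
    return {"result": result, "has_datasource": datasource is not None}


def _get_data_response(
    result: Dict[Any, Any],
    form_data: Optional[Dict[str, Any]] = None,
    datasource: Optional[object] = None,
) -> Dict[str, Any]:
    # Pass-through hop, mirroring the _get_data_response change above.
    return _send_chart_response(result, form_data, datasource)


# Pre-existing call sites (no datasource) are unaffected; new ones opt in.
assert _get_data_response({"data": []})["has_datasource"] is False
assert _get_data_response({"data": []}, {}, datasource=object())["has_datasource"] is True
```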
11 changes: 8 additions & 3 deletions superset/charts/post_processing.py
@@ -27,13 +27,16 @@
"""

from io import StringIO
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import pandas as pd

from superset.common.chart_data import ChartDataResultFormat
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name

if TYPE_CHECKING:
    from superset.connectors.base.models import BaseDatasource


def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
    """
@@ -284,7 +287,9 @@ def table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:


def apply_post_process(
    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
    result: Dict[Any, Any],
    form_data: Optional[Dict[str, Any]] = None,
    datasource: Optional["BaseDatasource"] = None,
) -> Dict[Any, Any]:
    form_data = form_data or {}

@@ -306,7 +311,7 @@ def apply_post_process(

        query["colnames"] = list(processed_df.columns)
        query["indexnames"] = list(processed_df.index)
        query["coltypes"] = extract_dataframe_dtypes(processed_df)
        query["coltypes"] = extract_dataframe_dtypes(processed_df, datasource)
        query["rowcount"] = len(processed_df.index)

        # Flatten hierarchical columns/index since they are represented as
2 changes: 1 addition & 1 deletion superset/common/query_actions.py
@@ -104,7 +104,7 @@ def _get_full(
    if status != QueryStatus.FAILED:
        payload["colnames"] = list(df.columns)
        payload["indexnames"] = list(df.index)
        payload["coltypes"] = extract_dataframe_dtypes(df)
        payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
        payload["data"] = query_context.get_data(df)
        payload["result_format"] = query_context.result_format
        del payload["df"]
16 changes: 14 additions & 2 deletions superset/utils/core.py
@@ -1597,7 +1597,9 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
    return [col for col in map(get_column_name_from_metric, metrics) if col]


def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
def extract_dataframe_dtypes(
    df: pd.DataFrame, datasource: Optional["BaseDatasource"] = None,
) -> List[GenericDataType]:
    """Serialize pandas/numpy dtypes to generic types"""

    # omitting string types as those will be the default type
@@ -1612,11 +1614,21 @@ def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
        "date": GenericDataType.TEMPORAL,
    }

    columns_by_name = (
        {column.column_name: column for column in datasource.columns}
        if datasource
        else {}
    )
    generic_types: List[GenericDataType] = []
    for column in df.columns:
        column_object = columns_by_name.get(column)
        series = df[column]
        inferred_type = infer_dtype(series)
        generic_type = inferred_type_map.get(inferred_type, GenericDataType.STRING)
        generic_type = (
            GenericDataType.TEMPORAL
            if column_object and column_object.is_dttm
            else inferred_type_map.get(inferred_type, GenericDataType.STRING)
        )
        generic_types.append(generic_type)

    return generic_types
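
To illustrate the behavioural change, here is a hedged sketch (not taken from the PR): the stand-in datasource carries only the two attributes the new code reads, `columns` and each column's `column_name`/`is_dttm`. A column stored as strings used to fall back to `STRING`; with the datasource hint it is now reported as `TEMPORAL`.

```python
from types import SimpleNamespace

import pandas as pd

from superset.utils.core import GenericDataType, extract_dataframe_dtypes

# Stand-in datasource: extract_dataframe_dtypes only reads `datasource.columns`
# and each column's `column_name` / `is_dttm` attributes.
datasource = SimpleNamespace(columns=[SimpleNamespace(column_name="ds", is_dttm=True)])

# "ds" holds plain strings, so pandas infers it as a string column.
df = pd.DataFrame({"ds": ["2017-01-01", "2017-01-02"], "num_boys": [5, 7]})

# Without the datasource hint, "ds" falls back to STRING (the old behaviour).
assert extract_dataframe_dtypes(df) == [GenericDataType.STRING, GenericDataType.NUMERIC]

# With the datasource, the is_dttm flag wins and "ds" is reported as TEMPORAL.
assert extract_dataframe_dtypes(df, datasource) == [
    GenericDataType.TEMPORAL,
    GenericDataType.NUMERIC,
]
```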
7 changes: 6 additions & 1 deletion tests/integration_tests/utils_tests.py
@@ -1121,7 +1121,9 @@ def test_get_form_data_token(self):
        generated_token = get_form_data_token({})
        assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_extract_dataframe_dtypes(self):
        slc = self.get_slice("Girls", db.session)
        cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
            ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
            (
@@ -1147,10 +1149,13 @@ def test_extract_dataframe_dtypes(self):
            ("float_null", GenericDataType.NUMERIC, [None, 0.5]),
            ("bool_null", GenericDataType.BOOLEAN, [None, False]),
            ("obj_null", GenericDataType.STRING, [None, {"a": 1}]),
            # Non-timestamp columns should be identified as temporal if
            # `is_dttm` is set to `True` in the underlying datasource
            ("ds", GenericDataType.TEMPORAL, [None, {"ds": "2017-01-01"}]),
        )

        df = pd.DataFrame(data={col[0]: col[2] for col in cols})
        assert extract_dataframe_dtypes(df) == [col[1] for col in cols]
        assert extract_dataframe_dtypes(df, slc.datasource) == [col[1] for col in cols]

    def test_normalize_dttm_col(self):
        def normalize_col(