Commit
refactor: queryObject - decouple from queryContext and clean code (#1…
ofekisr committed Nov 18, 2021
1 parent 56d742f commit b914e2d
Showing 2 changed files with 95 additions and 62 deletions.
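
In short, the commit inverts a dependency: QueryObject no longer receives the whole QueryContext just to read its result_type; QueryContext now hands each QueryObject only the result type it needs. A minimal, dependency-free sketch of the pattern (toy names, not Superset code):

from enum import Enum
from typing import Any, Dict, List


class ResultType(Enum):
    FULL = "full"
    SAMPLES = "samples"


class Query:
    """Toy stand-in for QueryObject: holds only the value it actually reads."""

    def __init__(self, parent_result_type: ResultType, **kwargs: Any) -> None:
        # A per-query result_type (passed via kwargs) still overrides the parent's.
        self.result_type = kwargs.get("result_type", parent_result_type)


class Context:
    """Toy stand-in for QueryContext: owns the default and passes it down."""

    def __init__(
        self,
        queries: List[Dict[str, Any]],
        result_type: ResultType = ResultType.FULL,
    ) -> None:
        self.result_type = result_type
        self.queries = [Query(self.result_type, **q) for q in queries]


ctx = Context(queries=[{}, {"result_type": ResultType.SAMPLES}])
assert ctx.queries[0].result_type is ResultType.FULL
assert ctx.queries[1].result_type is ResultType.SAMPLES
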
16 changes: 9 additions & 7 deletions superset/common/query_context.py
@@ -80,10 +80,10 @@ class QueryContext:

     datasource: BaseDatasource
     queries: List[QueryObject]
-    force: bool
-    custom_cache_timeout: Optional[int]
     result_type: ChartDataResultType
     result_format: ChartDataResultFormat
+    force: bool
+    custom_cache_timeout: Optional[int]

     # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
     # a vanilla python type https://github.com/python/mypy/issues/5288
@@ -92,19 +92,21 @@ def __init__(
         self,
         datasource: DatasourceDict,
         queries: List[Dict[str, Any]],
-        force: bool = False,
-        custom_cache_timeout: Optional[int] = None,
         result_type: Optional[ChartDataResultType] = None,
         result_format: Optional[ChartDataResultFormat] = None,
+        force: bool = False,
+        custom_cache_timeout: Optional[int] = None,
     ) -> None:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.force = force
-        self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
-        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
+        self.queries = [
+            QueryObject(self.result_type, **query_obj) for query_obj in queries
+        ]
+        self.force = force
+        self.custom_cache_timeout = custom_cache_timeout
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
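
Taken together, the query_context.py side of the change means a QueryContext is built exactly as before; only the internal hand-off to each QueryObject changes. A hedged usage sketch (assumes a configured Superset installation with an app context and a registered table datasource of id 1; import paths are those of this era of the codebase):

from superset.common.query_context import QueryContext

# result_type and result_format are omitted, so they fall back to
# ChartDataResultType.FULL and ChartDataResultFormat.JSON per the diff above.
query_context = QueryContext(
    datasource={"type": "table", "id": 1},
    queries=[{"metrics": ["count"], "time_range": "Last week"}],
    force=False,
)

# Before this commit each QueryObject was handed `self`; now it receives
# only query_context.result_type.
print(query_context.queries[0].result_type)
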
141 changes: 86 additions & 55 deletions superset/common/query_object.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
+from __future__ import annotations
+
 import logging
 from datetime import datetime, timedelta
 from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
@@ -106,11 +108,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     series_limit_metric: Optional[Metric]
     time_offsets: List[str]
     time_shift: Optional[timedelta]
+    time_range: Optional[str]
     to_dttm: Optional[datetime]

     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
-        query_context: "QueryContext",
+        parent_result_type: ChartDataResultType,
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
@@ -125,7 +128,6 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         order_desc: bool = True,
         orderby: Optional[List[OrderBy]] = None,
         post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
-        result_type: Optional[ChartDataResultType] = None,
         row_limit: Optional[int] = None,
         row_offset: Optional[int] = None,
         series_columns: Optional[List[Column]] = None,
@@ -135,88 +137,117 @@
         time_shift: Optional[str] = None,
         **kwargs: Any,
     ):
-        columns = columns or []
-        extras = extras or {}
-        annotation_layers = annotation_layers or []
+        self.result_type = kwargs.get("result_type", parent_result_type)
+        self._set_annotation_layers(annotation_layers)
+        self.applied_time_extras = applied_time_extras or {}
+        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
+        self.columns = columns or []
+        self._set_datasource(datasource)
+        self._set_extras(extras)
+        self.filter = filters or []
+        self.granularity = granularity
+        self.is_rowcount = is_rowcount
+        self._set_is_timeseries(is_timeseries)
+        self._set_metrics(metrics)
+        self.order_desc = order_desc
+        self.orderby = orderby or []
+        self._set_post_processing(post_processing)
+        self._set_row_limit(row_limit)
+        self.row_offset = row_offset or 0
+        self._init_series_columns(series_columns, metrics, is_timeseries)
+        self.series_limit = series_limit
+        self.series_limit_metric = series_limit_metric
+        self.set_dttms(time_range, time_shift)
+        self.time_range = time_range
+        self.time_shift = parse_human_timedelta(time_shift)
         self.time_offsets = kwargs.get("time_offsets", [])
         self.inner_from_dttm = kwargs.get("inner_from_dttm")
         self.inner_to_dttm = kwargs.get("inner_to_dttm")
-        if series_columns:
-            self.series_columns = series_columns
-        elif is_timeseries and metrics:
-            self.series_columns = columns
-        else:
-            self.series_columns = []
+
+        self._rename_deprecated_fields(kwargs)
+        self._move_deprecated_extra_fields(kwargs)

-        self.is_rowcount = is_rowcount
-        self.datasource = None
-        if datasource:
-            self.datasource = ConnectorRegistry.get_datasource(
-                str(datasource["type"]), int(datasource["id"]), db.session
-            )
-        self.result_type = result_type or query_context.result_type
-        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
+    def _set_annotation_layers(
+        self, annotation_layers: Optional[List[Dict[str, Any]]]
+    ) -> None:
         self.annotation_layers = [
             layer
-            for layer in annotation_layers
+            for layer in (annotation_layers or [])
             # formula annotations don't affect the payload, hence can be dropped
             if layer["annotationType"] != "FORMULA"
         ]
-        self.applied_time_extras = applied_time_extras or {}
-        self.granularity = granularity
-        self.from_dttm, self.to_dttm = get_since_until(
-            relative_start=extras.get(
-                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
-            ),
-            relative_end=extras.get(
-                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
-            ),
-            time_range=time_range,
-            time_shift=time_shift,
-        )

+    def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None:
+        self.datasource = None
+        if datasource:
+            self.datasource = ConnectorRegistry.get_datasource(
+                str(datasource["type"]), int(datasource["id"]), db.session
+            )
+
+    def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None:
+        self.extras = extras or {}
+        if config["SIP_15_ENABLED"]:
+            self.extras["time_range_endpoints"] = get_time_range_endpoints(
+                form_data=self.extras
+            )
+
+    def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
         # is_timeseries is True if time column is in either columns or groupby
         # (both are dimensions)
         self.is_timeseries = (
-            is_timeseries if is_timeseries is not None else DTTM_ALIAS in columns
+            is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
         )
-        self.time_range = time_range
-        self.time_shift = parse_human_timedelta(time_shift)
-        self.post_processing = [
-            post_proc for post_proc in post_processing or [] if post_proc
-        ]

+    def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
         # Support metric reference/definition in the format of
         # 1. 'metric_name' - name of predefined metric
         # 2. { label: 'label_name' } - legacy format for a predefined metric
         # 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
+        def is_str_or_adhoc(metric: Metric) -> bool:
+            return isinstance(metric, str) or is_adhoc_metric(metric)
+
         self.metrics = metrics and [
-            x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]  # type: ignore
-            for x in metrics
+            x if is_str_or_adhoc(x) else x["label"] for x in metrics  # type: ignore
         ]

+    def _set_post_processing(
+        self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
+    ) -> None:
+        self.post_processing = [
+            post_proc for post_proc in post_processing or [] if post_proc
+        ]
+
+    def _set_row_limit(self, row_limit: Optional[int]) -> None:
         default_row_limit = (
             config["SAMPLES_ROW_LIMIT"]
             if self.result_type == ChartDataResultType.SAMPLES
             else config["ROW_LIMIT"]
         )
         self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
-        self.row_offset = row_offset or 0
-        self.filter = filters or []
-        self.series_limit = series_limit
-        self.series_limit_metric = series_limit_metric
-        self.order_desc = order_desc
-        self.extras = extras
-
-        if config["SIP_15_ENABLED"]:
-            self.extras["time_range_endpoints"] = get_time_range_endpoints(
-                form_data=self.extras
-            )
-
-        self.columns = columns
-        self.orderby = orderby or []

+    def _init_series_columns(
+        self,
+        series_columns: Optional[List[Column]],
+        metrics: Optional[List[Metric]],
+        is_timeseries: Optional[bool],
+    ) -> None:
+        if series_columns:
+            self.series_columns = series_columns
+        elif is_timeseries and metrics:
+            self.series_columns = self.columns
+        else:
+            self.series_columns = []

-        self._rename_deprecated_fields(kwargs)
-        self._move_deprecated_extra_fields(kwargs)
+    def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None:
+        self.from_dttm, self.to_dttm = get_since_until(
+            relative_start=self.extras.get(
+                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
+            ),
+            relative_end=self.extras.get(
+                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
+            ),
+            time_range=time_range,
+            time_shift=time_shift,
+        )

     def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
         # rename deprecated fields
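
The payoff of the decoupling is that a QueryObject can now be constructed, e.g. in a unit test, without first building a QueryContext. A hedged sketch (assumes a configured Superset environment; ChartDataResultType lived in superset.utils.core at the time of this commit and may have moved since):

from superset.common.query_object import QueryObject
from superset.utils.core import ChartDataResultType

# No QueryContext needed: only the parent's result type is passed down.
query_object = QueryObject(
    parent_result_type=ChartDataResultType.FULL,
    metrics=["count"],
    columns=["name"],
    time_range="Last week",
)

# result_type was dropped from the explicit signature, but a per-query value
# still arrives via **kwargs and overrides the parent's default:
samples = QueryObject(
    ChartDataResultType.FULL,
    result_type=ChartDataResultType.SAMPLES,
)
assert samples.result_type == ChartDataResultType.SAMPLES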
