From 9447381549bf930ba850694f854854d77747f43e Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Tue, 14 Apr 2020 12:06:10 +0300 Subject: [PATCH] deprecate groupby controls in query_obj (#9366) * Deprecate groupby from query_obj * Fix query_object bug * Fix histogram * Remove groupby from legacy druid connector and fix first batch of unit tests * Deprecate some unnecessary tests + fix a few others * Address comments * hide SIP-38 changes behind feature flag * Break out further SIP-38 related tests * Reslove test errors * Add feature flag to QueryContext * Resolve tests and bad rebase * Backport recent changes from viz.py and fix broken DeckGL charts * Fix bad rebase * backport #9522 and address comments --- superset/common/query_object.py | 20 +- superset/config.py | 1 + superset/connectors/druid/models.py | 55 +- superset/connectors/sqla/models.py | 32 +- superset/models/slice.py | 6 +- superset/views/utils.py | 8 +- superset/viz_sip38.py | 2852 +++++++++++++++++++++++++++ tests/druid_func_tests.py | 20 +- tests/druid_func_tests_sip38.py | 1157 +++++++++++ 9 files changed, 4110 insertions(+), 41 deletions(-) create mode 100644 superset/viz_sip38.py create mode 100644 tests/druid_func_tests_sip38.py diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 72e9dfef8f51..63158c6751bb 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=R import hashlib +import logging from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Union @@ -23,11 +24,13 @@ from flask_babel import gettext as _ from pandas import DataFrame -from superset import app +from superset import app, is_feature_enabled from superset.exceptions import QueryObjectValidationError from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints +logger = logging.getLogger(__name__) + # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type # https://github.com/python/mypy/issues/5288 @@ -75,6 +78,7 @@ def __init__( relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"], relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"], ): + is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") self.granularity = granularity self.from_dttm, self.to_dttm = utils.get_since_until( relative_start=relative_start, @@ -85,8 +89,9 @@ def __init__( self.is_timeseries = is_timeseries self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) - self.groupby = groupby or [] self.post_processing = post_processing or [] + if not is_sip_38: + self.groupby = groupby or [] # Temporary solution for backward compatibility issue due the new format of # non-ad-hoc metric which needs to adhere to superset-ui per @@ -107,6 +112,13 @@ def __init__( self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) self.columns = columns or [] + if is_sip_38 and groupby: + self.columns += groupby + logger.warning( + f"The field groupby is deprecated. 
Viz plugins should " + f"pass all selectables via the columns field" + ) + self.orderby = orderby or [] def to_dict(self) -> Dict[str, Any]: @@ -115,7 +127,6 @@ def to_dict(self) -> Dict[str, Any]: "from_dttm": self.from_dttm, "to_dttm": self.to_dttm, "is_timeseries": self.is_timeseries, - "groupby": self.groupby, "metrics": self.metrics, "row_limit": self.row_limit, "filter": self.filter, @@ -126,6 +137,9 @@ def to_dict(self) -> Dict[str, Any]: "columns": self.columns, "orderby": self.orderby, } + if not is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + query_object_dict["groupby"] = self.groupby + return query_object_dict def cache_key(self, **extra: Any) -> str: diff --git a/superset/config.py b/superset/config.py index 28beb276a31a..91866fbfe2d0 100644 --- a/superset/config.py +++ b/superset/config.py @@ -287,6 +287,7 @@ def _try_json_readsha(filepath, length): # pylint: disable=unused-argument "PRESTO_EXPAND_DATA": False, "REDUCE_DASHBOARD_BOOTSTRAP_PAYLOAD": False, "SHARE_QUERIES_VIA_KV_STORE": False, + "SIP_38_VIZ_REARCHITECTURE": False, "TAGGING_SYSTEM": False, "SQLLAB_BACKEND_PERSISTENCE": False, "LIST_VIEWS_NEW_UI": False, diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 73ca6d5f2cf6..a4cc52748a68 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -48,7 +48,7 @@ from sqlalchemy.orm import backref, relationship, Session from sqlalchemy_utils import EncryptedType -from superset import conf, db, security_manager +from superset import conf, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.exceptions import SupersetException @@ -84,6 +84,7 @@ except ImportError: pass +IS_SIP_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") DRUID_TZ = conf.get("DRUID_TZ") POST_AGG_TYPE = "postagg" metadata = Model.metadata # pylint: disable=no-member @@ -1082,11 +1083,11 @@ def get_aggregations( return aggregations def get_dimensions( - self, groupby: List[str], columns_dict: Dict[str, DruidColumn] + self, columns: List[str], columns_dict: Dict[str, DruidColumn] ) -> List[Union[str, Dict]]: dimensions = [] - groupby = [gb for gb in groupby if gb in columns_dict] - for column_name in groupby: + columns = [col for col in columns if col in columns_dict] + for column_name in columns: col = columns_dict.get(column_name) dim_spec = col.dimension_spec if col else None dimensions.append(dim_spec or column_name) @@ -1137,11 +1138,12 @@ def sanitize_metric_object(metric: Dict) -> None: def run_query( # druid self, - groupby, metrics, granularity, from_dttm, to_dttm, + columns=None, + groupby=None, filter=None, is_timeseries=True, timeseries_limit=None, @@ -1151,7 +1153,6 @@ def run_query( # druid inner_to_dttm=None, orderby=None, extras=None, - columns=None, phase=2, client=None, order_desc=True, @@ -1188,7 +1189,11 @@ def run_query( # druid ) # the dimensions list with dimensionSpecs expanded - dimensions = self.get_dimensions(groupby, columns_dict) + + dimensions = self.get_dimensions( + columns if IS_SIP_38 else groupby, columns_dict + ) + extras = extras or {} qry = dict( datasource=self.datasource_name, @@ -1214,7 +1219,9 @@ def run_query( # druid order_direction = "descending" if order_desc else "ascending" - if columns: + if (IS_SIP_38 and not metrics and "__time" not in columns) or ( + not IS_SIP_38 and columns + ): columns.append("__time") del qry["post_aggregations"] del 
qry["aggregations"] @@ -1224,11 +1231,20 @@ def run_query( # druid qry["granularity"] = "all" qry["limit"] = row_limit client.scan(**qry) - elif len(groupby) == 0 and not having_filters: + elif (IS_SIP_38 and columns) or ( + not IS_SIP_38 and len(groupby) == 0 and not having_filters + ): logger.info("Running timeseries query for no groupby values") del qry["dimensions"] client.timeseries(**qry) - elif not having_filters and len(groupby) == 1 and order_desc: + elif ( + not having_filters + and order_desc + and ( + (IS_SIP_38 and len(columns) == 1) + or (not IS_SIP_38 and len(groupby) == 1) + ) + ): dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) pre_qry = deepcopy(qry) @@ -1279,7 +1295,10 @@ def run_query( # druid qry["metric"] = list(qry["aggregations"].keys())[0] client.topn(**qry) logger.info("Phase 2 Complete") - elif len(groupby) > 0 or having_filters: + elif ( + having_filters + or ((IS_SIP_38 and columns) or (not IS_SIP_38 and len(groupby))) > 0 + ): # If grouping on multiple fields or using a having filter # we have to force a groupby query logger.info("Running groupby query for dimensions [{}]".format(dimensions)) @@ -1364,8 +1383,8 @@ def run_query( # druid return query_str @staticmethod - def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame: - """Converting all GROUPBY columns to strings + def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: + """Converting all columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns strings in the dataframe. This creates issues downstream related @@ -1374,7 +1393,7 @@ def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFr Here we replace None with and make the whole series a str instead of an object. 
""" - df[groupby_cols] = df[groupby_cols].fillna(NULL_STRING).astype("unicode") + df[columns] = df[columns].fillna(NULL_STRING).astype("unicode") return df def query(self, query_obj: Dict) -> QueryResult: @@ -1390,7 +1409,9 @@ def query(self, query_obj: Dict) -> QueryResult: df=df, query=query_str, duration=datetime.now() - qry_start_dttm ) - df = self.homogenize_types(df, query_obj.get("groupby", [])) + df = self.homogenize_types( + df, query_obj.get("columns" if IS_SIP_38 else "groupby", []) + ) df.columns = [ DTTM_ALIAS if c in ("timestamp", "__time") else c for c in df.columns ] @@ -1405,7 +1426,9 @@ def query(self, query_obj: Dict) -> QueryResult: cols: List[str] = [] if DTTM_ALIAS in df.columns: cols += [DTTM_ALIAS] - cols += query_obj.get("groupby") or [] + + if not IS_SIP_38: + cols += query_obj.get("groupby") or [] cols += query_obj.get("columns") or [] cols += query_obj.get("metrics") or [] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 642997a10554..c290363dd8c5 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -49,7 +49,7 @@ from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.expression import Label, Select, TextAsFrom -from superset import app, db, security_manager +from superset import app, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression @@ -696,11 +696,12 @@ def _get_sqla_row_level_filters(self, template_processor) -> List[str]: def get_sqla_query( # sqla self, - groupby, metrics, granularity, from_dttm, to_dttm, + columns=None, + groupby=None, filter=None, is_timeseries=True, timeseries_limit=15, @@ -710,7 +711,6 @@ def get_sqla_query( # sqla inner_to_dttm=None, orderby=None, extras=None, - columns=None, order_desc=True, ) -> SqlaQuery: """Querying any sqla table from this common interface""" @@ -723,6 +723,7 @@ def get_sqla_query( # sqla "filter": filter, "columns": {col.column_name: col for col in self.columns}, } + is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys @@ -749,7 +750,11 @@ def get_sqla_query( # sqla "and is required by this type of chart" ) ) - if not groupby and not metrics and not columns: + if ( + not metrics + and not columns + and (is_sip_38 or (not is_sip_38 and not groupby)) + ): raise Exception(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] for m in metrics: @@ -768,9 +773,9 @@ def get_sqla_query( # sqla select_exprs: List[Column] = [] groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() - if groupby: + if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby): # dedup columns while preserving order - groupby = list(dict.fromkeys(groupby)) + groupby = list(dict.fromkeys(columns if is_sip_38 else groupby)) select_exprs = [] for s in groupby: @@ -829,7 +834,7 @@ def get_sqla_query( # sqla tbl = self.get_from_clause(template_processor) - if not columns: + if (is_sip_38 and metrics) or (not is_sip_38 and not columns): qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] @@ -892,7 +897,7 @@ def get_sqla_query( # sqla qry = qry.where(and_(*where_clause_and)) qry = qry.having(and_(*having_clause_and)) - if not orderby and not columns: + if not 
orderby and ((is_sip_38 and metrics) or (not is_sip_38 and not columns)): orderby = [(main_metric_expr, not order_desc)] # To ensure correct handling of the ORDER BY labeling we need to reference the @@ -914,7 +919,12 @@ def get_sqla_query( # sqla if row_limit: qry = qry.limit(row_limit) - if is_timeseries and timeseries_limit and groupby and not time_groupby_inline: + if ( + is_timeseries + and timeseries_limit + and not time_groupby_inline + and ((is_sip_38 and columns) or (not is_sip_38 and groupby)) + ): if self.database.db_engine_spec.allows_joins: # some sql dialects require for order by expressions # to also be in the select clause -- others, e.g. vertica, @@ -972,7 +982,6 @@ def get_sqla_query( # sqla prequery_obj = { "is_timeseries": False, "row_limit": timeseries_limit, - "groupby": groupby, "metrics": metrics, "granularity": granularity, "from_dttm": inner_from_dttm or from_dttm, @@ -983,6 +992,9 @@ def get_sqla_query( # sqla "columns": columns, "order_desc": True, } + if not is_sip_38: + prequery_obj["groupby"] = groupby + result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ diff --git a/superset/models/slice.py b/superset/models/slice.py index 6cfadab07a95..59544407c8bf 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -31,7 +31,11 @@ from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.tags import ChartUpdater from superset.utils import core as utils -from superset.viz import BaseViz, viz_types + +if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + from superset.viz_sip38 import BaseViz, viz_types # type: ignore +else: + from superset.viz import BaseViz, viz_types # type: ignore if TYPE_CHECKING: # pylint: disable=unused-import diff --git a/superset/views/utils.py b/superset/views/utils.py index 0f8629df87b8..edb2987ab429 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -23,7 +23,7 @@ from flask import request import superset.models.core as models -from superset import app, db, viz +from superset import app, db, is_feature_enabled from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import SupersetException from superset.legacy import update_time_range @@ -31,6 +31,12 @@ from superset.models.slice import Slice from superset.utils.core import QueryStatus, TimeRangeEndpoint +if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): + from superset import viz_sip38 as viz # type: ignore +else: + from superset import viz # type: ignore + + FORM_DATA_KEY_BLACKLIST: List[str] = [] if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]: FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"] diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py new file mode 100644 index 000000000000..1992bd1c14bb --- /dev/null +++ b/superset/viz_sip38.py @@ -0,0 +1,2852 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=C,R,W +"""This module contains the 'Viz' objects + +These objects represent the backend of all the visualizations that +Superset can render. +""" +import copy +import hashlib +import inspect +import logging +import math +import pickle as pkl +import re +import uuid +from collections import defaultdict, OrderedDict +from datetime import datetime, timedelta +from itertools import product +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING + +import geohash +import numpy as np +import pandas as pd +import polyline +import simplejson as json +from dateutil import relativedelta as rdelta +from flask import request +from flask_babel import lazy_gettext as _ +from geopy.point import Point +from markdown import markdown +from pandas.tseries.frequencies import to_offset + +from superset import app, cache, get_manifest_files, security_manager +from superset.constants import NULL_STRING +from superset.exceptions import NullValueException, SpatialException +from superset.models.helpers import QueryResult +from superset.typing import VizData +from superset.utils import core as utils +from superset.utils.core import ( + DTTM_ALIAS, + JS_MAX_INTEGER, + merge_extra_filters, + to_adhoc, +) + +if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource + +config = app.config +stats_logger = config["STATS_LOGGER"] +relative_start = config["DEFAULT_RELATIVE_START_TIME"] +relative_end = config["DEFAULT_RELATIVE_END_TIME"] +logger = logging.getLogger(__name__) + +METRIC_KEYS = [ + "metric", + "metrics", + "percent_metrics", + "metric_2", + "secondary_metric", + "x", + "y", + "size", +] + +COLUMN_FORM_DATA_PARAMS = [ + "all_columns", + "all_columns_x", + "all_columns_y", + "columns", + "dimension", + "entity", + "geojson", + "groupby", + "series", + "line_column", + "js_columns", +] + +SPATIAL_COLUMN_FORM_DATA_PARAMS = ["spatial", "start_spatial", "end_spatial"] + + +class BaseViz: + + """All visualizations derive this base class""" + + viz_type: Optional[str] = None + verbose_name = "Base Viz" + credits = "" + is_timeseries = False + cache_type = "df" + enforce_numerical_metrics = True + + def __init__( + self, + datasource: "BaseDatasource", + form_data: Dict[str, Any], + force: bool = False, + ): + if not datasource: + raise Exception(_("Viz is missing a datasource")) + + self.datasource = datasource + self.request = request + self.viz_type = form_data.get("viz_type") + self.form_data = form_data + + self.query = "" + self.token = self.form_data.get("token", "token_" + uuid.uuid4().hex[:8]) + + # merge all selectable columns into `columns` property + self.columns: List[str] = [] + for key in COLUMN_FORM_DATA_PARAMS: + value = self.form_data.get(key) or [] + value_list = value if isinstance(value, list) else [value] + if value_list: + logger.warning( + f"The form field %s is deprecated. 
Viz plugins should " + f"pass all selectables via the columns field", + key, + ) + self.columns += value_list + + for key in SPATIAL_COLUMN_FORM_DATA_PARAMS: + spatial = self.form_data.get(key) + if not isinstance(spatial, dict): + continue + logger.warning( + f"The form field %s is deprecated. Viz plugins should " + f"pass all selectables via the columns field", + key, + ) + if spatial.get("type") == "latlong": + self.columns += [spatial["lonCol"], spatial["latCol"]] + elif spatial.get("type") == "delimited": + self.columns.append(spatial["lonlatCol"]) + elif spatial.get("type") == "geohash": + self.columns.append(spatial["geohashCol"]) + + self.time_shift = timedelta() + + self.status: Optional[str] = None + self.error_msg = "" + self.results: Optional[QueryResult] = None + self.error_message: Optional[str] = None + self.force = force + self.from_ddtm: Optional[datetime] = None + self.to_dttm: Optional[datetime] = None + + # Keeping track of whether some data came from cache + # this is useful to trigger the when + # in the cases where visualization have many queries + # (FilterBox for instance) + self._any_cache_key: Optional[str] = None + self._any_cached_dttm: Optional[str] = None + self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = [] + + self.process_metrics() + + def process_metrics(self): + # metrics in TableViz is order sensitive, so metric_dict should be + # OrderedDict + self.metric_dict = OrderedDict() + fd = self.form_data + for mkey in METRIC_KEYS: + val = fd.get(mkey) + if val: + if not isinstance(val, list): + val = [val] + for o in val: + label = utils.get_metric_name(o) + self.metric_dict[label] = o + + # Cast to list needed to return serializable object in py3 + self.all_metrics = list(self.metric_dict.values()) + self.metric_labels = list(self.metric_dict.keys()) + + @staticmethod + def handle_js_int_overflow(data): + for d in data.get("records", dict()): + for k, v in list(d.items()): + if isinstance(v, int): + # if an int is too big for Java Script to handle + # convert it to a string + if abs(v) > JS_MAX_INTEGER: + d[k] = str(v) + return data + + def run_extra_queries(self): + """Lifecycle method to use when more than one query is needed + + In rare-ish cases, a visualization may need to execute multiple + queries. That is the case for FilterBox or for time comparison + in Line chart for instance. + + In those cases, we need to make sure these queries run before the + main `get_payload` method gets called, so that the overall caching + metadata can be right. The way it works here is that if any of + the previous `get_df_payload` calls hit the cache, the main + payload's metadata will reflect that. + + The multi-query support may need more work to become a first class + use case in the framework, and for the UI to reflect the subtleties + (show that only some of the queries were served from cache for + instance). In the meantime, since multi-query is rare, we treat + it with a bit of a hack. Note that the hack became necessary + when moving from caching the visualization's data itself, to caching + the underlying query(ies). 
+ """ + pass + + def apply_rolling(self, df): + fd = self.form_data + rolling_type = fd.get("rolling_type") + rolling_periods = int(fd.get("rolling_periods") or 0) + min_periods = int(fd.get("min_periods") or 0) + + if rolling_type in ("mean", "std", "sum") and rolling_periods: + kwargs = dict(window=rolling_periods, min_periods=min_periods) + if rolling_type == "mean": + df = df.rolling(**kwargs).mean() + elif rolling_type == "std": + df = df.rolling(**kwargs).std() + elif rolling_type == "sum": + df = df.rolling(**kwargs).sum() + elif rolling_type == "cumsum": + df = df.cumsum() + if min_periods: + df = df[min_periods:] + return df + + def get_samples(self): + query_obj = self.query_obj() + query_obj.update( + { + "metrics": [], + "row_limit": 1000, + "columns": [o.column_name for o in self.datasource.columns], + } + ) + df = self.get_df(query_obj) + return df.to_dict(orient="records") + + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + """Returns a pandas dataframe based on the query object""" + if not query_obj: + query_obj = self.query_obj() + if not query_obj: + return pd.DataFrame() + + self.error_msg = "" + + timestamp_format = None + if self.datasource.type == "table": + granularity_col = self.datasource.get_column(query_obj["granularity"]) + if granularity_col: + timestamp_format = granularity_col.python_date_format + + # The datasource here can be different backend but the interface is common + self.results = self.datasource.query(query_obj) + self.query = self.results.query + self.status = self.results.status + self.error_message = self.results.error_message + + df = self.results.df + # Transform the timestamp we received from database to pandas supported + # datetime format. If no python_date_format is specified, the pattern will + # be considered as the default ISO date format + # If the datetime format is unix, the parse will use the corresponding + # parsing logic. + if not df.empty: + if DTTM_ALIAS in df.columns: + if timestamp_format in ("epoch_s", "epoch_ms"): + # Column has already been formatted as a timestamp. 
+ dttm_col = df[DTTM_ALIAS] + one_ts_val = dttm_col[0] + + # convert time column to pandas Timestamp, but different + # ways to convert depending on string or int types + try: + int(one_ts_val) + is_integral = True + except (ValueError, TypeError): + is_integral = False + if is_integral: + unit = "s" if timestamp_format == "epoch_s" else "ms" + df[DTTM_ALIAS] = pd.to_datetime( + dttm_col, utc=False, unit=unit, origin="unix" + ) + else: + df[DTTM_ALIAS] = dttm_col.apply(pd.Timestamp) + else: + df[DTTM_ALIAS] = pd.to_datetime( + df[DTTM_ALIAS], utc=False, format=timestamp_format + ) + if self.datasource.offset: + df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset) + df[DTTM_ALIAS] += self.time_shift + + if self.enforce_numerical_metrics: + self.df_metrics_to_num(df) + + df.replace([np.inf, -np.inf], np.nan, inplace=True) + return df + + def df_metrics_to_num(self, df): + """Converting metrics to numeric when pandas.read_sql cannot""" + metrics = self.metric_labels + for col, dtype in df.dtypes.items(): + if dtype.type == np.object_ and col in metrics: + df[col] = pd.to_numeric(df[col], errors="coerce") + + def process_query_filters(self): + utils.convert_legacy_filters_into_adhoc(self.form_data) + merge_extra_filters(self.form_data) + utils.split_adhoc_filters_into_base_filters(self.form_data) + + def query_obj(self) -> Dict[str, Any]: + """Building a query object""" + form_data = self.form_data + self.process_query_filters() + metrics = self.all_metrics or [] + columns = self.columns + + is_timeseries = self.is_timeseries + if DTTM_ALIAS in columns: + columns.remove(DTTM_ALIAS) + is_timeseries = True + + granularity = form_data.get("granularity") or form_data.get("granularity_sqla") + limit = int(form_data.get("limit") or 0) + timeseries_limit_metric = form_data.get("timeseries_limit_metric") + row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"]) + + # default order direction + order_desc = form_data.get("order_desc", True) + + since, until = utils.get_since_until( + relative_start=relative_start, + relative_end=relative_end, + time_range=form_data.get("time_range"), + since=form_data.get("since"), + until=form_data.get("until"), + ) + time_shift = form_data.get("time_shift", "") + self.time_shift = utils.parse_past_timedelta(time_shift) + from_dttm = None if since is None else (since - self.time_shift) + to_dttm = None if until is None else (until - self.time_shift) + if from_dttm and to_dttm and from_dttm > to_dttm: + raise Exception(_("From date cannot be larger than to date")) + + self.from_dttm = from_dttm + self.to_dttm = to_dttm + + # extras are used to query elements specific to a datasource type + # for instance the extra where clause that applies only to Tables + extras = { + "druid_time_origin": form_data.get("druid_time_origin", ""), + "having": form_data.get("having", ""), + "having_druid": form_data.get("having_filters", []), + "time_grain_sqla": form_data.get("time_grain_sqla"), + "time_range_endpoints": form_data.get("time_range_endpoints"), + "where": form_data.get("where", ""), + } + + d = { + "granularity": granularity, + "from_dttm": from_dttm, + "to_dttm": to_dttm, + "is_timeseries": is_timeseries, + "columns": columns, + "metrics": metrics, + "row_limit": row_limit, + "filter": self.form_data.get("filters", []), + "timeseries_limit": limit, + "extras": extras, + "timeseries_limit_metric": timeseries_limit_metric, + "order_desc": order_desc, + } + return d + + @property + def cache_timeout(self): + if self.form_data.get("cache_timeout") is not None: 
+ return int(self.form_data.get("cache_timeout")) + if self.datasource.cache_timeout is not None: + return self.datasource.cache_timeout + if ( + hasattr(self.datasource, "database") + and self.datasource.database.cache_timeout + ) is not None: + return self.datasource.database.cache_timeout + return config["CACHE_DEFAULT_TIMEOUT"] + + def get_json(self): + return json.dumps( + self.get_payload(), default=utils.json_int_dttm_ser, ignore_nan=True + ) + + def cache_key(self, query_obj, **extra): + """ + The cache key is made out of the key/values in `query_obj`, plus any + other key/values in `extra`. + + We remove datetime bounds that are hard values, and replace them with + the use-provided inputs to bounds, which may be time-relative (as in + "5 days ago" or "now"). + + The `extra` arguments are currently used by time shift queries, since + different time shifts wil differ only in the `from_dttm` and `to_dttm` + values which are stripped. + """ + cache_dict = copy.copy(query_obj) + cache_dict.update(extra) + + for k in ["from_dttm", "to_dttm"]: + del cache_dict[k] + + cache_dict["time_range"] = self.form_data.get("time_range") + cache_dict["datasource"] = self.datasource.uid + cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj) + cache_dict["rls"] = security_manager.get_rls_ids(self.datasource) + cache_dict["changed_on"] = self.datasource.changed_on + json_data = self.json_dumps(cache_dict, sort_keys=True) + return hashlib.md5(json_data.encode("utf-8")).hexdigest() + + def get_payload(self, query_obj=None): + """Returns a payload of metadata and data""" + self.run_extra_queries() + payload = self.get_df_payload(query_obj) + + df = payload.get("df") + if self.status != utils.QueryStatus.FAILED: + payload["data"] = self.get_data(df) + if "df" in payload: + del payload["df"] + return payload + + def get_df_payload(self, query_obj=None, **kwargs): + """Handles caching around the df payload retrieval""" + if not query_obj: + query_obj = self.query_obj() + cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None + logger.info("Cache key: {}".format(cache_key)) + is_loaded = False + stacktrace = None + df = None + cached_dttm = datetime.utcnow().isoformat().split(".")[0] + if cache_key and cache and not self.force: + cache_value = cache.get(cache_key) + if cache_value: + stats_logger.incr("loading_from_cache") + try: + cache_value = pkl.loads(cache_value) + df = cache_value["df"] + self.query = cache_value["query"] + self._any_cached_dttm = cache_value["dttm"] + self._any_cache_key = cache_key + self.status = utils.QueryStatus.SUCCESS + is_loaded = True + stats_logger.incr("loaded_from_cache") + except Exception as ex: + logger.exception(ex) + logger.error( + "Error reading cache: " + utils.error_msg_from_exception(ex) + ) + logger.info("Serving from cache") + + if query_obj and not is_loaded: + try: + df = self.get_df(query_obj) + if self.status != utils.QueryStatus.FAILED: + stats_logger.incr("loaded_from_source") + if not self.force: + stats_logger.incr("loaded_from_source_without_force") + is_loaded = True + except Exception as ex: + logger.exception(ex) + if not self.error_message: + self.error_message = "{}".format(ex) + self.status = utils.QueryStatus.FAILED + stacktrace = utils.get_stacktrace() + + if ( + is_loaded + and cache_key + and cache + and self.status != utils.QueryStatus.FAILED + ): + try: + cache_value = dict(dttm=cached_dttm, df=df, query=self.query) + cache_value = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL) + + 
logger.info( + "Caching {} chars at key {}".format(len(cache_value), cache_key) + ) + + stats_logger.incr("set_cache_key") + cache.set(cache_key, cache_value, timeout=self.cache_timeout) + except Exception as ex: + # cache.set call can fail if the backend is down or if + # the key is too large or whatever other reasons + logger.warning("Could not cache key {}".format(cache_key)) + logger.exception(ex) + cache.delete(cache_key) + + return { + "cache_key": self._any_cache_key, + "cached_dttm": self._any_cached_dttm, + "cache_timeout": self.cache_timeout, + "df": df, + "error": self.error_message, + "form_data": self.form_data, + "is_cached": self._any_cache_key is not None, + "query": self.query, + "from_dttm": self.from_dttm, + "to_dttm": self.to_dttm, + "status": self.status, + "stacktrace": stacktrace, + "rowcount": len(df.index) if df is not None else 0, + } + + def json_dumps(self, obj, sort_keys=False): + return json.dumps( + obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys + ) + + def payload_json_and_has_error(self, payload): + has_error = ( + payload.get("status") == utils.QueryStatus.FAILED + or payload.get("error") is not None + ) + return self.json_dumps(payload), has_error + + @property + def data(self): + """This is the data object serialized to the js layer""" + content = { + "form_data": self.form_data, + "token": self.token, + "viz_name": self.viz_type, + "filter_select_enabled": self.datasource.filter_select_enabled, + } + return content + + def get_csv(self): + df = self.get_df() + include_index = not isinstance(df.index, pd.RangeIndex) + return df.to_csv(index=include_index, **config["CSV_EXPORT"]) + + def get_data(self, df: pd.DataFrame) -> VizData: + return df.to_dict(orient="records") + + @property + def json_data(self): + return json.dumps(self.data) + + +class TableViz(BaseViz): + + """A basic html table that is sortable and searchable""" + + viz_type = "table" + verbose_name = _("Table View") + credits = 'a Superset original' + is_timeseries = False + enforce_numerical_metrics = False + + def should_be_timeseries(self): + fd = self.form_data + # TODO handle datasource-type-specific code in datasource + conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or ( + fd.get("granularity_sqla") and fd.get("time_grain_sqla") + ) + if fd.get("include_time") and not conditions_met: + raise Exception( + _("Pick a granularity in the Time section or " "uncheck 'Include Time'") + ) + return fd.get("include_time") + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + + if fd.get("all_columns") and ( + fd.get("groupby") or fd.get("metrics") or fd.get("percent_metrics") + ): + raise Exception( + _( + "Choose either fields to [Group By] and [Metrics] and/or " + "[Percentage Metrics], or [Columns], not both" + ) + ) + + sort_by = fd.get("timeseries_limit_metric") + if fd.get("all_columns"): + order_by_cols = fd.get("order_by_cols") or [] + d["orderby"] = [json.loads(t) for t in order_by_cols] + elif sort_by: + sort_by_label = utils.get_metric_name(sort_by) + if sort_by_label not in utils.get_metric_names(d["metrics"]): + d["metrics"] += [sort_by] + d["orderby"] = [(sort_by, not fd.get("order_desc", True))] + + # Add all percent metrics that are not already in the list + if "percent_metrics" in fd: + d["metrics"].extend( + m for m in fd["percent_metrics"] or [] if m not in d["metrics"] + ) + + d["is_timeseries"] = self.should_be_timeseries() + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + """ + 
Transform the query result to the table representation. + + :param df: The interim dataframe + :returns: The table visualization data + + The interim dataframe comprises of the group-by and non-group-by columns and + the union of the metrics representing the non-percent and percent metrics. Note + the percent metrics have yet to be transformed. + """ + + non_percent_metric_columns = [] + # Transform the data frame to adhere to the UI ordering of the columns and + # metrics whilst simultaneously computing the percentages (via normalization) + # for the percent metrics. + + if DTTM_ALIAS in df: + if self.should_be_timeseries(): + non_percent_metric_columns.append(DTTM_ALIAS) + else: + del df[DTTM_ALIAS] + + non_percent_metric_columns.extend( + self.form_data.get("all_columns") or self.form_data.get("groupby") or [] + ) + + non_percent_metric_columns.extend( + utils.get_metric_names(self.form_data.get("metrics") or []) + ) + + timeseries_limit_metric = utils.get_metric_name( + self.form_data.get("timeseries_limit_metric") + ) + if timeseries_limit_metric: + non_percent_metric_columns.append(timeseries_limit_metric) + + percent_metric_columns = utils.get_metric_names( + self.form_data.get("percent_metrics") or [] + ) + + if not df.empty: + df = pd.concat( + [ + df[non_percent_metric_columns], + ( + df[percent_metric_columns] + .div(df[percent_metric_columns].sum()) + .add_prefix("%") + ), + ], + axis=1, + ) + + data = self.handle_js_int_overflow( + dict(records=df.to_dict(orient="records"), columns=list(df.columns)) + ) + + return data + + def json_dumps(self, obj, sort_keys=False): + return json.dumps( + obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True + ) + + +class TimeTableViz(BaseViz): + + """A data table with rich time-series related columns""" + + viz_type = "time_table" + verbose_name = _("Time Table View") + credits = 'a Superset original' + is_timeseries = True + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + + if not fd.get("metrics"): + raise Exception(_("Pick at least one metric")) + + if fd.get("groupby") and len(fd.get("metrics")) > 1: + raise Exception( + _("When using 'Group By' you are limited to use a single metric") + ) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + columns = None + values = self.metric_labels + if fd.get("groupby"): + values = self.metric_labels[0] + columns = fd.get("groupby") + pt = df.pivot_table(index=DTTM_ALIAS, columns=columns, values=values) + pt.index = pt.index.map(str) + pt = pt.sort_index() + return dict( + records=pt.to_dict(orient="index"), + columns=list(pt.columns), + is_group_by=len(fd.get("groupby", [])) > 0, + ) + + +class PivotTableViz(BaseViz): + + """A pivot table view, define your rows, columns and metrics""" + + viz_type = "pivot_table" + verbose_name = _("Pivot Table") + credits = 'a Superset original' + is_timeseries = False + + def query_obj(self): + d = super().query_obj() + groupby = self.form_data.get("groupby") + columns = self.form_data.get("columns") + metrics = self.form_data.get("metrics") + transpose = self.form_data.get("transpose_pivot") + if not columns: + columns = [] + if not groupby: + groupby = [] + if not groupby: + raise Exception(_("Please choose at least one 'Group by' field ")) + if transpose and not columns: + raise Exception( + _( + ( + "Please choose at least one 'Columns' field when " + "select 'Transpose Pivot' option" + ) + ) + ) + if not metrics: + raise Exception(_("Please choose 
at least one metric")) + if set(groupby) & set(columns): + raise Exception(_("Group By' and 'Columns' can't overlap")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df: + del df[DTTM_ALIAS] + + aggfunc = self.form_data.get("pandas_aggfunc") or "sum" + + # Ensure that Pandas's sum function mimics that of SQL. + if aggfunc == "sum": + aggfunc = lambda x: x.sum(min_count=1) + + groupby = self.form_data.get("groupby") + columns = self.form_data.get("columns") + if self.form_data.get("transpose_pivot"): + groupby, columns = columns, groupby + metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]] + df = df.pivot_table( + index=groupby, + columns=columns, + values=metrics, + aggfunc=aggfunc, + margins=self.form_data.get("pivot_margins"), + ) + + # Re-order the columns adhering to the metric ordering. + df = df[metrics] + + # Display metrics side by side with each column + if self.form_data.get("combine_metric"): + df = df.stack(0).unstack() + return dict( + columns=list(df.columns), + html=df.to_html( + na_rep="null", + classes=( + "dataframe table table-striped table-bordered " + "table-condensed table-hover" + ).split(" "), + ), + ) + + +class MarkupViz(BaseViz): + + """Use html or markdown to create a free form widget""" + + viz_type = "markup" + verbose_name = _("Markup") + is_timeseries = False + + def query_obj(self): + return None + + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + return pd.DataFrame() + + def get_data(self, df: pd.DataFrame) -> VizData: + markup_type = self.form_data.get("markup_type") + code = self.form_data.get("code", "") + if markup_type == "markdown": + code = markdown(code) + return dict(html=code, theme_css=get_manifest_files("theme", "css")) + + +class SeparatorViz(MarkupViz): + + """Use to create section headers in a dashboard, similar to `Markup`""" + + viz_type = "separator" + verbose_name = _("Separator") + + +class WordCloudViz(BaseViz): + + """Build a colorful word cloud + + Uses the nice library at: + https://github.com/jasondavies/d3-cloud + """ + + viz_type = "word_cloud" + verbose_name = _("Word Cloud") + is_timeseries = False + + +class TreemapViz(BaseViz): + + """Tree map visualisation for hierarchical data.""" + + viz_type = "treemap" + verbose_name = _("Treemap") + credits = 'd3.js' + is_timeseries = False + + def _nest(self, metric, df): + nlevels = df.index.nlevels + if nlevels == 1: + result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])] + else: + result = [ + {"name": l, "children": self._nest(metric, df.loc[l])} + for l in df.index.levels[0] + ] + return result + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + df = df.set_index(self.form_data.get("groupby")) + chart_data = [ + {"name": metric, "children": self._nest(metric, df)} + for metric in df.columns + ] + return chart_data + + +class CalHeatmapViz(BaseViz): + + """Calendar heatmap.""" + + viz_type = "cal_heatmap" + verbose_name = _("Calendar Heatmap") + credits = "cal-heatmap" + is_timeseries = True + + def get_data(self, df: pd.DataFrame) -> VizData: + form_data = self.form_data + + data = {} + records = df.to_dict("records") + for metric in self.metric_labels: + values = {} + for obj in records: + v = obj[DTTM_ALIAS] + if hasattr(v, "value"): + v = v.value + values[str(v / 10 ** 9)] = obj.get(metric) + data[metric] = values + + start, end = utils.get_since_until( + 
relative_start=relative_start, + relative_end=relative_end, + time_range=form_data.get("time_range"), + since=form_data.get("since"), + until=form_data.get("until"), + ) + if not start or not end: + raise Exception("Please provide both time bounds (Since and Until)") + domain = form_data.get("domain_granularity") + diff_delta = rdelta.relativedelta(end, start) + diff_secs = (end - start).total_seconds() + + if domain == "year": + range_ = diff_delta.years + 1 + elif domain == "month": + range_ = diff_delta.years * 12 + diff_delta.months + 1 + elif domain == "week": + range_ = diff_delta.years * 53 + diff_delta.weeks + 1 + elif domain == "day": + range_ = diff_secs // (24 * 60 * 60) + 1 # type: ignore + else: + range_ = diff_secs // (60 * 60) + 1 # type: ignore + + return { + "data": data, + "start": start, + "domain": domain, + "subdomain": form_data.get("subdomain_granularity"), + "range": range_, + } + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + d["metrics"] = fd.get("metrics") + return d + + +class NVD3Viz(BaseViz): + + """Base class for all nvd3 vizs""" + + credits = 'NVD3.org' + viz_type: Optional[str] = None + verbose_name = "Base NVD3 Viz" + is_timeseries = False + + +class BoxPlotViz(NVD3Viz): + + """Box plot viz from ND3""" + + viz_type = "box_plot" + verbose_name = _("Box Plot") + sort_series = False + is_timeseries = True + + def to_series(self, df, classed="", title_suffix=""): + label_sep = " - " + chart_data = [] + for index_value, row in zip(df.index, df.to_dict(orient="records")): + if isinstance(index_value, tuple): + index_value = label_sep.join(index_value) + boxes = defaultdict(dict) + for (label, key), value in row.items(): + if key == "nanmedian": + key = "Q2" + boxes[label][key] = value + for label, box in boxes.items(): + if len(self.form_data.get("metrics")) > 1: + # need to render data labels with metrics + chart_label = label_sep.join([index_value, label]) + else: + chart_label = index_value + chart_data.append({"label": chart_label, "values": box}) + return chart_data + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + form_data = self.form_data + + # conform to NVD3 names + def Q1(series): # need to be named functions - can't use lambdas + return np.nanpercentile(series, 25) + + def Q3(series): + return np.nanpercentile(series, 75) + + whisker_type = form_data.get("whisker_options") + if whisker_type == "Tukey": + + def whisker_high(series): + upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series)) + return series[series <= upper_outer_lim].max() + + def whisker_low(series): + lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series)) + return series[series >= lower_outer_lim].min() + + elif whisker_type == "Min/max (no outliers)": + + def whisker_high(series): + return series.max() + + def whisker_low(series): + return series.min() + + elif " percentiles" in whisker_type: # type: ignore + low, high = whisker_type.replace(" percentiles", "").split( # type: ignore + "/" + ) + + def whisker_high(series): + return np.nanpercentile(series, int(high)) + + def whisker_low(series): + return np.nanpercentile(series, int(low)) + + else: + raise ValueError("Unknown whisker type: {}".format(whisker_type)) + + def outliers(series): + above = series[series > whisker_high(series)] + below = series[series < whisker_low(series)] + # pandas sometimes doesn't like getting lists back here + return set(above.tolist() + below.tolist()) + + aggregate = [Q1, np.nanmedian, Q3, whisker_high, whisker_low, 
outliers] + df = df.groupby(form_data.get("groupby")).agg(aggregate) + chart_data = self.to_series(df) + return chart_data + + +class BubbleViz(NVD3Viz): + + """Based on the NVD3 bubble chart""" + + viz_type = "bubble" + verbose_name = _("Bubble Chart") + is_timeseries = False + + def query_obj(self): + form_data = self.form_data + d = super().query_obj() + + self.x_metric = form_data.get("x") + self.y_metric = form_data.get("y") + self.z_metric = form_data.get("size") + self.entity = form_data.get("entity") + self.series = form_data.get("series") or self.entity + d["row_limit"] = form_data.get("limit") + + d["metrics"] = [self.z_metric, self.x_metric, self.y_metric] + if len(set(self.metric_labels)) < 3: + raise Exception(_("Please use 3 different metric labels")) + if not all(d["metrics"] + [self.entity]): + raise Exception(_("Pick a metric for x, y and size")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + df["x"] = df[[utils.get_metric_name(self.x_metric)]] + df["y"] = df[[utils.get_metric_name(self.y_metric)]] + df["size"] = df[[utils.get_metric_name(self.z_metric)]] + df["shape"] = "circle" + df["group"] = df[[self.series]] + + series: Dict[Any, List[Any]] = defaultdict(list) + for row in df.to_dict(orient="records"): + series[row["group"]].append(row) + chart_data = [] + for k, v in series.items(): + chart_data.append({"key": k, "values": v}) + return chart_data + + +class BulletViz(NVD3Viz): + + """Based on the NVD3 bullet chart""" + + viz_type = "bullet" + verbose_name = _("Bullet Chart") + is_timeseries = False + + def query_obj(self): + form_data = self.form_data + d = super().query_obj() + self.metric = form_data.get("metric") + + def as_strings(field): + value = form_data.get(field) + return value.split(",") if value else [] + + def as_floats(field): + return [float(x) for x in as_strings(field)] + + self.ranges = as_floats("ranges") + self.range_labels = as_strings("range_labels") + self.markers = as_floats("markers") + self.marker_labels = as_strings("marker_labels") + self.marker_lines = as_floats("marker_lines") + self.marker_line_labels = as_strings("marker_line_labels") + + d["metrics"] = [self.metric] + if not self.metric: + raise Exception(_("Pick a metric to display")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + df["metric"] = df[[utils.get_metric_name(self.metric)]] + values = df["metric"].values + return { + "measures": values.tolist(), + "ranges": self.ranges or [0, values.max() * 1.1], + "rangeLabels": self.range_labels or None, + "markers": self.markers or None, + "markerLabels": self.marker_labels or None, + "markerLines": self.marker_lines or None, + "markerLineLabels": self.marker_line_labels or None, + } + + +class BigNumberViz(BaseViz): + + """Put emphasis on a single metric with this big number viz""" + + viz_type = "big_number" + verbose_name = _("Big Number with Trendline") + credits = 'a Superset original' + is_timeseries = True + + def query_obj(self): + d = super().query_obj() + metric = self.form_data.get("metric") + if not metric: + raise Exception(_("Pick a metric!")) + d["metrics"] = [self.form_data.get("metric")] + self.form_data["metric"] = metric + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=[], + values=self.metric_labels, + dropna=False, + aggfunc=np.min, # looking for any (only) value, preserving `None` + ) + df = self.apply_rolling(df) + df[DTTM_ALIAS] = df.index + return super().get_data(df) + + 
+class BigNumberTotalViz(BaseViz): + + """Put emphasis on a single metric with this big number viz""" + + viz_type = "big_number_total" + verbose_name = _("Big Number") + credits = 'a Superset original' + is_timeseries = False + + def query_obj(self): + d = super().query_obj() + metric = self.form_data.get("metric") + if not metric: + raise Exception(_("Pick a metric!")) + d["metrics"] = [self.form_data.get("metric")] + self.form_data["metric"] = metric + + # Limiting rows is not required as only one cell is returned + d["row_limit"] = None + return d + + +class NVD3TimeSeriesViz(NVD3Viz): + + """A rich line chart component with tons of options""" + + viz_type = "line" + verbose_name = _("Time Series - Line Chart") + sort_series = False + is_timeseries = True + pivot_fill_value: Optional[int] = None + + def to_series(self, df, classed="", title_suffix=""): + cols = [] + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + series = df.to_dict("series") + + chart_data = [] + for name in df.T.index.tolist(): + ys = series[name] + if df[name].dtype.kind not in "biufc": + continue + if isinstance(name, list): + series_title = [str(title) for title in name] + elif isinstance(name, tuple): + series_title = tuple(str(title) for title in name) + else: + series_title = str(name) + if ( + isinstance(series_title, (list, tuple)) + and len(series_title) > 1 + and len(self.metric_labels) == 1 + ): + # Removing metric from series name if only one metric + series_title = series_title[1:] + if title_suffix: + if isinstance(series_title, str): + series_title = (series_title, title_suffix) + elif isinstance(series_title, (list, tuple)): + series_title = series_title + (title_suffix,) + + values = [] + non_nan_cnt = 0 + for ds in df.index: + if ds in ys: + d = {"x": ds, "y": ys[ds]} + if not np.isnan(ys[ds]): + non_nan_cnt += 1 + else: + d = {} + values.append(d) + + if non_nan_cnt == 0: + continue + + d = {"key": series_title, "values": values} + if classed: + d["classed"] = classed + chart_data.append(d) + return chart_data + + def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: + fd = self.form_data + if fd.get("granularity") == "all": + raise Exception(_("Pick a time granularity for your time series")) + + if df.empty: + return df + + if aggregate: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=self.columns, + values=self.metric_labels, + fill_value=0, + aggfunc=sum, + ) + else: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=self.columns, + values=self.metric_labels, + fill_value=self.pivot_fill_value, + ) + + rule = fd.get("resample_rule") + method = fd.get("resample_method") + + if rule and method: + df = getattr(df.resample(rule), method)() + + if self.sort_series: + dfs = df.sum() + dfs.sort_values(ascending=False, inplace=True) + df = df[dfs.index] + + df = self.apply_rolling(df) + if fd.get("contribution"): + dft = df.T + df = (dft / dft.sum()).T + + return df + + def run_extra_queries(self): + fd = self.form_data + + time_compare = fd.get("time_compare") or [] + # backwards compatibility + if not isinstance(time_compare, list): + time_compare = [time_compare] + + for option in time_compare: + query_object = self.query_obj() + delta = utils.parse_past_timedelta(option) + query_object["inner_from_dttm"] = query_object["from_dttm"] + query_object["inner_to_dttm"] = query_object["to_dttm"] + + if not query_object["from_dttm"] or not query_object["to_dttm"]: + raise 
Exception( + _( + "`Since` and `Until` time bounds should be specified " + "when using the `Time Shift` feature." + ) + ) + query_object["from_dttm"] -= delta + query_object["to_dttm"] -= delta + + df2 = self.get_df_payload(query_object, time_compare=option).get("df") + if df2 is not None and DTTM_ALIAS in df2: + label = "{} offset".format(option) + df2[DTTM_ALIAS] += delta + df2 = self.process_data(df2) + self._extra_chart_data.append((label, df2)) + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + comparison_type = fd.get("comparison_type") or "values" + df = self.process_data(df) + if comparison_type == "values": + # Filter out series with all NaN + chart_data = self.to_series(df.dropna(axis=1, how="all")) + + for i, (label, df2) in enumerate(self._extra_chart_data): + chart_data.extend( + self.to_series( + df2, classed="time-shift-{}".format(i), title_suffix=label + ) + ) + else: + chart_data = [] + for i, (label, df2) in enumerate(self._extra_chart_data): + # reindex df2 into the df2 index + combined_index = df.index.union(df2.index) + df2 = ( + df2.reindex(combined_index) + .interpolate(method="time") + .reindex(df.index) + ) + + if comparison_type == "absolute": + diff = df - df2 + elif comparison_type == "percentage": + diff = (df - df2) / df2 + elif comparison_type == "ratio": + diff = df / df2 + else: + raise Exception( + "Invalid `comparison_type`: {0}".format(comparison_type) + ) + + # remove leading/trailing NaNs from the time shift difference + diff = diff[diff.first_valid_index() : diff.last_valid_index()] + + chart_data.extend( + self.to_series( + diff, classed="time-shift-{}".format(i), title_suffix=label + ) + ) + + if not self.sort_series: + chart_data = sorted(chart_data, key=lambda x: tuple(x["key"])) + return chart_data + + +class MultiLineViz(NVD3Viz): + + """Pile on multiple line charts""" + + viz_type = "line_multi" + verbose_name = _("Time Series - Multiple Line Charts") + + is_timeseries = True + + def query_obj(self): + return None + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + # Late imports to avoid circular import issues + from superset.models.slice import Slice + from superset import db + + slice_ids1 = fd.get("line_charts") + slices1 = db.session.query(Slice).filter(Slice.id.in_(slice_ids1)).all() + slice_ids2 = fd.get("line_charts_2") + slices2 = db.session.query(Slice).filter(Slice.id.in_(slice_ids2)).all() + return { + "slices": { + "axis1": [slc.data for slc in slices1], + "axis2": [slc.data for slc in slices2], + } + } + + +class NVD3DualLineViz(NVD3Viz): + + """A rich line chart with dual axis""" + + viz_type = "dual_line" + verbose_name = _("Time Series - Dual Axis Line Chart") + sort_series = False + is_timeseries = True + + def query_obj(self): + d = super().query_obj() + m1 = self.form_data.get("metric") + m2 = self.form_data.get("metric_2") + d["metrics"] = [m1, m2] + if not m1: + raise Exception(_("Pick a metric for left axis!")) + if not m2: + raise Exception(_("Pick a metric for right axis!")) + if m1 == m2: + raise Exception( + _("Please choose different metrics" " on left and right axis") + ) + return d + + def to_series(self, df, classed=""): + cols = [] + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + series = df.to_dict("series") + chart_data = [] + metrics = [self.form_data.get("metric"), self.form_data.get("metric_2")] + for i, m in enumerate(metrics): + m = 
utils.get_metric_name(m) + ys = series[m] + if df[m].dtype.kind not in "biufc": + continue + series_title = m + d = { + "key": series_title, + "classed": classed, + "values": [ + {"x": ds, "y": ys[ds] if ds in ys else None} for ds in df.index + ], + "yAxis": i + 1, + "type": "line", + } + chart_data.append(d) + return chart_data + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + + if self.form_data.get("granularity") == "all": + raise Exception(_("Pick a time granularity for your time series")) + + metric = utils.get_metric_name(fd.get("metric")) + metric_2 = utils.get_metric_name(fd.get("metric_2")) + df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2]) + + chart_data = self.to_series(df) + return chart_data + + +class NVD3TimeSeriesBarViz(NVD3TimeSeriesViz): + + """A bar chart where the x axis is time""" + + viz_type = "bar" + sort_series = True + verbose_name = _("Time Series - Bar Chart") + + +class NVD3TimePivotViz(NVD3TimeSeriesViz): + + """Time Series - Periodicity Pivot""" + + viz_type = "time_pivot" + sort_series = True + verbose_name = _("Time Series - Period Pivot") + + def query_obj(self): + d = super().query_obj() + d["metrics"] = [self.form_data.get("metric")] + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + df = self.process_data(df) + freq = to_offset(fd.get("freq")) + try: + freq = type(freq)(freq.n, normalize=True, **freq.kwds) + except ValueError: + freq = type(freq)(freq.n, **freq.kwds) + df.index.name = None + df[DTTM_ALIAS] = df.index.map(freq.rollback) + df["ranked"] = df[DTTM_ALIAS].rank(method="dense", ascending=False) - 1 + df.ranked = df.ranked.map(int) + df["series"] = "-" + df.ranked.map(str) + df["series"] = df["series"].str.replace("-0", "current") + rank_lookup = { + row["series"]: row["ranked"] for row in df.to_dict(orient="records") + } + max_ts = df[DTTM_ALIAS].max() + max_rank = df["ranked"].max() + df[DTTM_ALIAS] = df.index + (max_ts - df[DTTM_ALIAS]) + df = df.pivot_table( + index=DTTM_ALIAS, + columns="series", + values=utils.get_metric_name(fd.get("metric")), + ) + chart_data = self.to_series(df) + for serie in chart_data: + serie["rank"] = rank_lookup[serie["key"]] + serie["perc"] = 1 - (serie["rank"] / (max_rank + 1)) + return chart_data + + +class NVD3CompareTimeSeriesViz(NVD3TimeSeriesViz): + + """A line chart component where you can compare the % change over time""" + + viz_type = "compare" + verbose_name = _("Time Series - Percent Change") + + +class NVD3TimeSeriesStackedViz(NVD3TimeSeriesViz): + + """A rich stack area chart""" + + viz_type = "area" + verbose_name = _("Time Series - Stacked") + sort_series = True + pivot_fill_value = 0 + + +class DistributionPieViz(NVD3Viz): + + """Annoy visualization snobs with this controversial pie chart""" + + viz_type = "pie" + verbose_name = _("Distribution - NVD3 - Pie Chart") + is_timeseries = False + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + metric = self.metric_labels[0] + df = df.pivot_table(index=self.columns, values=[metric]) + df.sort_values(by=metric, ascending=False, inplace=True) + df = df.reset_index() + df.columns = ["x", "y"] + return df.to_dict(orient="records") + + +class HistogramViz(BaseViz): + + """Histogram""" + + viz_type = "histogram" + verbose_name = _("Histogram") + is_timeseries = False + + def query_obj(self): + """Returns the query object for this visualization""" + d = super().query_obj() + 
d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) + if not self.form_data.get("all_columns_x"): + raise Exception(_("Must have at least one numeric column specified")) + return d + + def labelify(self, keys, column): + if isinstance(keys, str): + keys = (keys,) + # removing undesirable characters + labels = [re.sub(r"\W+", r"_", k) for k in keys] + if len(self.columns) > 1: + # Only show numeric column in label if there are many + labels = [column] + labels + return "__".join(labels) + + def get_data(self, df: pd.DataFrame) -> VizData: + """Returns the chart data""" + groupby = self.form_data.get("groupby") + + if df.empty: + return None + + chart_data = [] + if groupby: + groups = df.groupby(groupby) + else: + groups = [((), df)] + for keys, data in groups: + chart_data.extend( + [ + { + "key": self.labelify(keys, column), + "values": data[column].tolist(), + } + for column in self.columns + ] + ) + return chart_data + + +class DistributionBarViz(DistributionPieViz): + + """A good old bar chart""" + + viz_type = "dist_bar" + verbose_name = _("Distribution - Bar Chart") + is_timeseries = False + + def query_obj(self): + # TODO: Refactor this plugin to either perform grouping or assume + # preaggretagion of metrics ("numeric columns") + d = super().query_obj() + fd = self.form_data + if not self.all_metrics: + raise Exception(_("Pick at least one metric")) + if not self.columns: + raise Exception(_("Pick at least one field for [Series]")) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + metrics = self.metric_labels + # TODO: will require post transformation logic not currently available in + # /api/v1/query endpoint + columns = fd.get("columns") or [] + groupby = fd.get("groupby") or [] + + # pandas will throw away nulls when grouping/pivoting, + # so we substitute NULL_STRING for any nulls in the necessary columns + df[self.columns] = df[self.columns].fillna(value=NULL_STRING) + + row = df.groupby(groupby).sum()[metrics[0]].copy() + row.sort_values(ascending=False, inplace=True) + pt = df.pivot_table(index=groupby, columns=columns, values=metrics) + if fd.get("contribution"): + pt = pt.T + pt = (pt / pt.sum()).T + pt = pt.reindex(row.index) + chart_data = [] + for name, ys in pt.items(): + if pt[name].dtype.kind not in "biufc" or name in groupby: + continue + if isinstance(name, str): + series_title = name + else: + offset = 0 if len(metrics) > 1 else 1 + series_title = ", ".join([str(s) for s in name[offset:]]) + values = [] + for i, v in ys.items(): + x = i + if isinstance(x, (tuple, list)): + x = ", ".join([str(s) for s in x]) + else: + x = str(x) + values.append({"x": x, "y": v}) + d = {"key": series_title, "values": values} + chart_data.append(d) + return chart_data + + +class SunburstViz(BaseViz): + + """A multi level sunburst chart""" + + viz_type = "sunburst" + verbose_name = _("Sunburst") + is_timeseries = False + credits = ( + "Kerry Rodden " + '@bl.ocks.org' + ) + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + cols = fd.get("groupby") or [] + cols.extend(["m1", "m2"]) + metric = utils.get_metric_name(fd.get("metric")) + secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + if metric == secondary_metric or secondary_metric is None: + df.rename(columns={df.columns[-1]: "m1"}, inplace=True) + df["m2"] = df["m1"] + else: + df.rename(columns={df.columns[-2]: "m1"}, inplace=True) + df.rename(columns={df.columns[-1]: "m2"}, inplace=True) + + # 
Re-order the columns as the query result set column ordering may differ from + # that listed in the hierarchy. + df = df[cols] + return df.to_numpy().tolist() + + def query_obj(self): + qry = super().query_obj() + fd = self.form_data + qry["metrics"] = [fd["metric"]] + secondary_metric = fd.get("secondary_metric") + if secondary_metric and secondary_metric != fd["metric"]: + qry["metrics"].append(secondary_metric) + return qry + + +class SankeyViz(BaseViz): + + """A Sankey diagram that requires a parent-child dataset""" + + viz_type = "sankey" + verbose_name = _("Sankey") + is_timeseries = False + credits = 'd3-sankey on npm' + + def get_data(self, df: pd.DataFrame) -> VizData: + df.columns = ["source", "target", "value"] + df["source"] = df["source"].astype(str) + df["target"] = df["target"].astype(str) + recs = df.to_dict(orient="records") + + hierarchy: Dict[str, Set[str]] = defaultdict(set) + for row in recs: + hierarchy[row["source"]].add(row["target"]) + + def find_cycle(g): + """Whether there's a cycle in a directed graph""" + path = set() + + def visit(vertex): + path.add(vertex) + for neighbour in g.get(vertex, ()): + if neighbour in path or visit(neighbour): + return (vertex, neighbour) + path.remove(vertex) + + for v in g: + cycle = visit(v) + if cycle: + return cycle + + cycle = find_cycle(hierarchy) + if cycle: + raise Exception( + _( + "There's a loop in your Sankey, please provide a tree. " + "Here's a faulty link: {}" + ).format(cycle) + ) + return recs + + +class DirectedForceViz(BaseViz): + + """An animated directed force layout graph visualization""" + + viz_type = "directed_force" + verbose_name = _("Directed Force Layout") + credits = 'd3noob @bl.ocks.org' + is_timeseries = False + + def query_obj(self): + qry = super().query_obj() + if len(self.form_data["groupby"]) != 2: + raise Exception(_("Pick exactly 2 columns to 'Group By'")) + qry["metrics"] = [self.form_data["metric"]] + return qry + + def get_data(self, df: pd.DataFrame) -> VizData: + df.columns = ["source", "target", "value"] + return df.to_dict(orient="records") + + +class ChordViz(BaseViz): + + """A Chord diagram""" + + viz_type = "chord" + verbose_name = _("Directed Force Layout") + credits = 'Bostock' + is_timeseries = False + + def query_obj(self): + qry = super().query_obj() + fd = self.form_data + qry["metrics"] = [fd.get("metric")] + return qry + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + df.columns = ["source", "target", "value"] + + # Preparing a symetrical matrix like d3.chords calls for + nodes = list(set(df["source"]) | set(df["target"])) + matrix = {} + for source, target in product(nodes, nodes): + matrix[(source, target)] = 0 + for source, target, value in df.to_records(index=False): + matrix[(source, target)] = value + m = [[matrix[(n1, n2)] for n1 in nodes] for n2 in nodes] + return {"nodes": list(nodes), "matrix": m} + + +class CountryMapViz(BaseViz): + + """A country centric""" + + viz_type = "country_map" + verbose_name = _("Country Map") + is_timeseries = False + credits = "From bl.ocks.org By john-guerra" + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + cols = [fd.get("entity")] + metric = self.metric_labels[0] + cols += [metric] + ndf = df[cols] + df = ndf + df.columns = ["country_id", "metric"] + d = df.to_dict(orient="records") + return d + + +class WorldMapViz(BaseViz): + + """A country centric world map""" + + viz_type = "world_map" + verbose_name = _("World Map") + is_timeseries = False + credits = 'datamaps on 
npm' + + def get_data(self, df: pd.DataFrame) -> VizData: + from superset.examples import countries + + fd = self.form_data + cols = [fd.get("entity")] + metric = utils.get_metric_name(fd.get("metric")) + secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + columns = ["country", "m1", "m2"] + if metric == secondary_metric: + ndf = df[cols] + ndf["m1"] = df[metric] + ndf["m2"] = ndf["m1"] + else: + if secondary_metric: + cols += [metric, secondary_metric] + else: + cols += [metric] + columns = ["country", "m1"] + ndf = df[cols] + df = ndf + df.columns = columns + d = df.to_dict(orient="records") + for row in d: + country = None + if isinstance(row["country"], str): + if "country_fieldtype" in fd: + country = countries.get(fd["country_fieldtype"], row["country"]) + if country: + row["country"] = country["cca3"] + row["latitude"] = country["lat"] + row["longitude"] = country["lng"] + row["name"] = country["name"] + else: + row["country"] = "XXX" + return d + + +class FilterBoxViz(BaseViz): + + """A multi filter, multi-choice filter box to make dashboards interactive""" + + viz_type = "filter_box" + verbose_name = _("Filters") + is_timeseries = False + credits = 'a Superset original' + cache_type = "get_data" + filter_row_limit = 1000 + + def query_obj(self): + return None + + def run_extra_queries(self): + qry = super().query_obj() + filters = self.form_data.get("filter_configs") or [] + qry["row_limit"] = self.filter_row_limit + self.dataframes = {} + for flt in filters: + col = flt.get("column") + if not col: + raise Exception( + _("Invalid filter configuration, please select a column") + ) + qry["columns"] = [col] + metric = flt.get("metric") + qry["metrics"] = [metric] if metric else [] + df = self.get_df_payload(query_obj=qry).get("df") + self.dataframes[col] = df + + def get_data(self, df: pd.DataFrame) -> VizData: + filters = self.form_data.get("filter_configs") or [] + d = {} + for flt in filters: + col = flt.get("column") + metric = flt.get("metric") + df = self.dataframes.get(col) + if df is not None: + if metric: + df = df.sort_values( + utils.get_metric_name(metric), ascending=flt.get("asc") + ) + d[col] = [ + {"id": row[0], "text": row[0], "metric": row[1]} + for row in df.itertuples(index=False) + ] + else: + df = df.sort_values(col, ascending=flt.get("asc")) + d[col] = [ + {"id": row[0], "text": row[0]} + for row in df.itertuples(index=False) + ] + return d + + +class IFrameViz(BaseViz): + + """You can squeeze just about anything in this iFrame component""" + + viz_type = "iframe" + verbose_name = _("iFrame") + credits = 'a Superset original' + is_timeseries = False + + def query_obj(self): + return None + + def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + return pd.DataFrame() + + def get_data(self, df: pd.DataFrame) -> VizData: + return {"iframe": True} + + +class ParallelCoordinatesViz(BaseViz): + + """Interactive parallel coordinate implementation + + Uses this amazing javascript library + https://github.com/syntagmatic/parallel-coordinates + """ + + viz_type = "para" + verbose_name = _("Parallel Coordinates") + credits = ( + '' + "Syntagmatic's library" + ) + is_timeseries = False + + def get_data(self, df: pd.DataFrame) -> VizData: + return df.to_dict(orient="records") + + +class HeatmapViz(BaseViz): + + """A nice heatmap visualization that support high density through canvas""" + + viz_type = "heatmap" + verbose_name = _("Heatmap") + is_timeseries = False + credits = ( + 'inspired from mbostock @' + "bl.ocks.org" + ) + + 
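+    # get_data renames the (x, y, metric) columns and attaches 'perc' and 'rank' columns,
+    # normalized either across the whole heatmap or within each 'normalize_across' group.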
def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + x = fd.get("all_columns_x") + y = fd.get("all_columns_y") + v = self.metric_labels[0] + if x == y: + df.columns = ["x", "y", "v"] + else: + df = df[[x, y, v]] + df.columns = ["x", "y", "v"] + norm = fd.get("normalize_across") + overall = False + max_ = df.v.max() + min_ = df.v.min() + if norm == "heatmap": + overall = True + else: + gb = df.groupby(norm, group_keys=False) + if len(gb) <= 1: + overall = True + else: + df["perc"] = gb.apply( + lambda x: (x.v - x.v.min()) / (x.v.max() - x.v.min()) + ) + df["rank"] = gb.apply(lambda x: x.v.rank(pct=True)) + if overall: + df["perc"] = (df.v - min_) / (max_ - min_) + df["rank"] = df.v.rank(pct=True) + return {"records": df.to_dict(orient="records"), "extents": [min_, max_]} + + +class HorizonViz(NVD3TimeSeriesViz): + + """Horizon chart + + https://www.npmjs.com/package/d3-horizon-chart + """ + + viz_type = "horizon" + verbose_name = _("Horizon Charts") + credits = ( + '' + "d3-horizon-chart" + ) + + +class MapboxViz(BaseViz): + + """Rich maps made with Mapbox""" + + viz_type = "mapbox" + verbose_name = _("Mapbox") + is_timeseries = False + credits = "Mapbox GL JS" + + def query_obj(self): + d = super().query_obj() + fd = self.form_data + label_col = fd.get("mapbox_label") + + if not fd.get("groupby"): + if fd.get("all_columns_x") is None or fd.get("all_columns_y") is None: + raise Exception(_("[Longitude] and [Latitude] must be set")) + d["columns"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] + + if label_col and len(label_col) >= 1: + if label_col[0] == "count": + raise Exception( + _( + "Must have a [Group By] column to have 'count' as the " + + "[Label]" + ) + ) + d["columns"].append(label_col[0]) + + if fd.get("point_radius") != "Auto": + d["columns"].append(fd.get("point_radius")) + + d["columns"] = list(set(d["columns"])) + else: + # Ensuring columns chosen are all in group by + if ( + label_col + and len(label_col) >= 1 + and label_col[0] != "count" + and label_col[0] not in fd.get("groupby") + ): + raise Exception(_("Choice of [Label] must be present in [Group By]")) + + if fd.get("point_radius") != "Auto" and fd.get( + "point_radius" + ) not in fd.get("groupby"): + raise Exception( + _("Choice of [Point Radius] must be present in [Group By]") + ) + + if fd.get("all_columns_x") not in fd.get("groupby") or fd.get( + "all_columns_y" + ) not in fd.get("groupby"): + raise Exception( + _( + "[Longitude] and [Latitude] columns must be present in " + + "[Group By]" + ) + ) + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + fd = self.form_data + label_col = fd.get("mapbox_label") + has_custom_metric = label_col is not None and len(label_col) > 0 + metric_col = [None] * len(df.index) + if has_custom_metric: + if label_col[0] == fd.get("all_columns_x"): # type: ignore + metric_col = df[fd.get("all_columns_x")] + elif label_col[0] == fd.get("all_columns_y"): # type: ignore + metric_col = df[fd.get("all_columns_y")] + else: + metric_col = df[label_col[0]] # type: ignore + point_radius_col = ( + [None] * len(df.index) + if fd.get("point_radius") == "Auto" + else df[fd.get("point_radius")] + ) + + # limiting geo precision as long decimal values trigger issues + # around json-bignumber in Mapbox + GEO_PRECISION = 10 + # using geoJSON formatting + geo_json = { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": {"metric": metric, "radius": point_radius}, + 
"geometry": { + "type": "Point", + "coordinates": [ + round(lon, GEO_PRECISION), + round(lat, GEO_PRECISION), + ], + }, + } + for lon, lat, metric, point_radius in zip( + df[fd.get("all_columns_x")], + df[fd.get("all_columns_y")], + metric_col, + point_radius_col, + ) + ], + } + + x_series, y_series = df[fd.get("all_columns_x")], df[fd.get("all_columns_y")] + south_west = [x_series.min(), y_series.min()] + north_east = [x_series.max(), y_series.max()] + + return { + "geoJSON": geo_json, + "hasCustomMetric": has_custom_metric, + "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapStyle": fd.get("mapbox_style"), + "aggregatorName": fd.get("pandas_aggfunc"), + "clusteringRadius": fd.get("clustering_radius"), + "pointRadiusUnit": fd.get("point_radius_unit"), + "globalOpacity": fd.get("global_opacity"), + "bounds": [south_west, north_east], + "renderWhileDragging": fd.get("render_while_dragging"), + "tooltip": fd.get("rich_tooltip"), + "color": fd.get("mapbox_color"), + } + + +class DeckGLMultiLayer(BaseViz): + + """Pile on multiple DeckGL layers""" + + viz_type = "deck_multi" + verbose_name = _("Deck.gl - Multiple Layers") + + is_timeseries = False + credits = 'deck.gl' + + def query_obj(self): + return None + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + # Late imports to avoid circular import issues + from superset.models.slice import Slice + from superset import db + + slice_ids = fd.get("deck_slices") + slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + return { + "mapboxApiKey": config["MAPBOX_API_KEY"], + "slices": [slc.data for slc in slices], + } + + +class BaseDeckGLViz(BaseViz): + + """Base class for deck.gl visualizations""" + + is_timeseries = False + credits = 'deck.gl' + spatial_control_keys: List[str] = [] + + def get_metrics(self): + self.metric = self.form_data.get("size") + return [self.metric] if self.metric else [] + + @staticmethod + def parse_coordinates(s): + if not s: + return None + try: + p = Point(s) + return (p.latitude, p.longitude) # pylint: disable=no-member + except Exception: + raise SpatialException(_("Invalid spatial point encountered: %s" % s)) + + @staticmethod + def reverse_geohash_decode(geohash_code): + lat, lng = geohash.decode(geohash_code) + return (lng, lat) + + @staticmethod + def reverse_latlong(df, key): + df[key] = [tuple(reversed(o)) for o in df[key] if isinstance(o, (list, tuple))] + + def process_spatial_data_obj(self, key, df): + spatial = self.form_data.get(key) + if spatial is None: + raise ValueError(_("Bad spatial key")) + + if spatial.get("type") == "latlong": + df[key] = list( + zip( + pd.to_numeric(df[spatial.get("lonCol")], errors="coerce"), + pd.to_numeric(df[spatial.get("latCol")], errors="coerce"), + ) + ) + elif spatial.get("type") == "delimited": + lon_lat_col = spatial.get("lonlatCol") + df[key] = df[lon_lat_col].apply(self.parse_coordinates) + del df[lon_lat_col] + elif spatial.get("type") == "geohash": + df[key] = df[spatial.get("geohashCol")].map(self.reverse_geohash_decode) + del df[spatial.get("geohashCol")] + + if spatial.get("reverseCheckbox"): + self.reverse_latlong(df, key) + + if df.get(key) is None: + raise NullValueException( + _( + "Encountered invalid NULL spatial entry, \ + please consider filtering those out" + ) + ) + return df + + def add_null_filters(self): + fd = self.form_data + spatial_columns = set() + + if fd.get("adhoc_filters") is None: + fd["adhoc_filters"] = [] + + line_column = fd.get("line_column") + if line_column: + spatial_columns.add(line_column) + 
+ for column in sorted(spatial_columns): + filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""}) + fd["adhoc_filters"].append(filter_) + + def query_obj(self): + fd = self.form_data + + # add NULL filters + if fd.get("filter_nulls", True): + self.add_null_filters() + + d = super().query_obj() + + metrics = self.get_metrics() + if metrics: + d["metrics"] = metrics + return d + + def get_js_columns(self, d): + cols = self.form_data.get("js_columns") or [] + return {col: d.get(col) for col in cols} + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + # Processing spatial info + for key in self.spatial_control_keys: + df = self.process_spatial_data_obj(key, df) + + features = [] + for d in df.to_dict(orient="records"): + feature = self.get_properties(d) + extra_props = self.get_js_columns(d) + if extra_props: + feature["extraProps"] = extra_props + features.append(feature) + + return { + "features": features, + "mapboxApiKey": config["MAPBOX_API_KEY"], + "metricLabels": self.metric_labels, + } + + def get_properties(self, d): + raise NotImplementedError() + + +class DeckScatterViz(BaseDeckGLViz): + + """deck.gl's ScatterLayer""" + + viz_type = "deck_scatter" + verbose_name = _("Deck.gl - Scatter plot") + spatial_control_keys = ["spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + self.point_radius_fixed = fd.get("point_radius_fixed") or { + "type": "fix", + "value": 500, + } + return super().query_obj() + + def get_metrics(self): + self.metric = None + if self.point_radius_fixed.get("type") == "metric": + self.metric = self.point_radius_fixed.get("value") + return [self.metric] + return None + + def get_properties(self, d): + return { + "metric": d.get(self.metric_label), + "radius": self.fixed_value + if self.fixed_value + else d.get(self.metric_label), + "cat_color": d.get(self.dim) if self.dim else None, + "position": d.get("spatial"), + DTTM_ALIAS: d.get(DTTM_ALIAS), + } + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + self.metric_label = utils.get_metric_name(self.metric) if self.metric else None + self.point_radius_fixed = fd.get("point_radius_fixed") + self.fixed_value = None + self.dim = self.form_data.get("dimension") + if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric": + self.fixed_value = self.point_radius_fixed.get("value") + return super().get_data(df) + + +class DeckScreengrid(BaseDeckGLViz): + + """deck.gl's ScreenGridLayer""" + + viz_type = "deck_screengrid" + verbose_name = _("Deck.gl - Screen Grid") + spatial_control_keys = ["spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") + return super().query_obj() + + def get_properties(self, d): + return { + "position": d.get("spatial"), + "weight": d.get(self.metric_label) or 1, + "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"), + } + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +class DeckGrid(BaseDeckGLViz): + + """deck.gl's DeckLayer""" + + viz_type = "deck_grid" + verbose_name = _("Deck.gl - 3D Grid") + spatial_control_keys = ["spatial"] + + def get_properties(self, d): + return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} + + def get_data(self, df: pd.DataFrame) -> VizData: + 
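+        # Resolve the metric label once, then reuse the shared deck.gl data path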
self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +def geohash_to_json(geohash_code): + p = geohash.bbox(geohash_code) + return [ + [p.get("w"), p.get("n")], + [p.get("e"), p.get("n")], + [p.get("e"), p.get("s")], + [p.get("w"), p.get("s")], + [p.get("w"), p.get("n")], + ] + + +class DeckPathViz(BaseDeckGLViz): + + """deck.gl's PathLayer""" + + viz_type = "deck_path" + verbose_name = _("Deck.gl - Paths") + deck_viz_key = "path" + is_timeseries = True + deser_map = { + "json": json.loads, + "polyline": polyline.decode, + "geohash": geohash_to_json, + } + + def query_obj(self): + fd = self.form_data + self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity") + d = super().query_obj() + self.metric = fd.get("metric") + if d["metrics"]: + self.has_metrics = True + else: + self.has_metrics = False + return d + + def get_properties(self, d): + fd = self.form_data + line_type = fd.get("line_type") + deser = self.deser_map[line_type] + line_column = fd.get("line_column") + path = deser(d[line_column]) + if fd.get("reverse_long_lat"): + path = [(o[1], o[0]) for o in path] + d[self.deck_viz_key] = path + if line_type != "geohash": + del d[line_column] + d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time") + return d + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super().get_data(df) + + +class DeckPolygon(DeckPathViz): + + """deck.gl's Polygon Layer""" + + viz_type = "deck_polygon" + deck_viz_key = "polygon" + verbose_name = _("Deck.gl - Polygon") + + def query_obj(self): + fd = self.form_data + self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500} + return super().query_obj() + + def get_metrics(self): + metrics = [self.form_data.get("metric")] + if self.elevation.get("type") == "metric": + metrics.append(self.elevation.get("value")) + return [metric for metric in metrics if metric] + + def get_properties(self, d): + super().get_properties(d) + fd = self.form_data + elevation = fd["point_radius_fixed"]["value"] + type_ = fd["point_radius_fixed"]["type"] + d["elevation"] = ( + d.get(utils.get_metric_name(elevation)) if type_ == "metric" else elevation + ) + return d + + +class DeckHex(BaseDeckGLViz): + + """deck.gl's DeckLayer""" + + viz_type = "deck_hex" + verbose_name = _("Deck.gl - 3D HEX") + spatial_control_keys = ["spatial"] + + def get_properties(self, d): + return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} + + def get_data(self, df: pd.DataFrame) -> VizData: + self.metric_label = utils.get_metric_name(self.metric) + return super(DeckHex, self).get_data(df) + + +class DeckGeoJson(BaseDeckGLViz): + + """deck.gl's GeoJSONLayer""" + + viz_type = "deck_geojson" + verbose_name = _("Deck.gl - GeoJSON") + + def get_properties(self, d): + geojson = d.get(self.form_data.get("geojson")) + return json.loads(geojson) + + +class DeckArc(BaseDeckGLViz): + + """deck.gl's Arc Layer""" + + viz_type = "deck_arc" + verbose_name = _("Deck.gl - Arc") + spatial_control_keys = ["start_spatial", "end_spatial"] + is_timeseries = True + + def query_obj(self): + fd = self.form_data + self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + return super().query_obj() + + def get_properties(self, d): + dim = self.form_data.get("dimension") + return { + "sourcePosition": d.get("start_spatial"), + "targetPosition": d.get("end_spatial"), + "cat_color": d.get(dim) if dim else None, + DTTM_ALIAS: d.get(DTTM_ALIAS), + } + + 
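+    # get_data below trims the base payload down to the arc features and the Mapbox API key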
def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + d = super().get_data(df) + + return { + "features": d["features"], # type: ignore + "mapboxApiKey": config["MAPBOX_API_KEY"], + } + + +class EventFlowViz(BaseViz): + + """A visualization to explore patterns in event sequences""" + + viz_type = "event_flow" + verbose_name = _("Event flow") + credits = 'from @data-ui' + is_timeseries = True + + def query_obj(self): + query = super().query_obj() + form_data = self.form_data + + event_key = form_data.get("all_columns_x") + entity_key = form_data.get("entity") + meta_keys = [ + col + for col in form_data.get("all_columns") + if col != event_key and col != entity_key + ] + + query["columns"] = [event_key, entity_key] + meta_keys + + if form_data["order_by_entity"]: + query["orderby"] = [(entity_key, True)] + + return query + + def get_data(self, df: pd.DataFrame) -> VizData: + return df.to_dict(orient="records") + + +class PairedTTestViz(BaseViz): + + """A table displaying paired t-test values""" + + viz_type = "paired_ttest" + verbose_name = _("Time Series - Paired t-test") + sort_series = False + is_timeseries = True + + def get_data(self, df: pd.DataFrame) -> VizData: + """ + Transform received data frame into an object of the form: + { + 'metric1': [ + { + groups: ('groupA', ... ), + values: [ {x, y}, ... ], + }, ... + ], ... + } + """ + + if df.empty: + return None + + fd = self.form_data + groups = fd.get("groupby") + metrics = self.metric_labels + df = df.pivot_table(index=DTTM_ALIAS, columns=groups, values=metrics) + cols = [] + # Be rid of falsey keys + for col in df.columns: + if col == "": + cols.append("N/A") + elif col is None: + cols.append("NULL") + else: + cols.append(col) + df.columns = cols + data: Dict = {} + series = df.to_dict("series") + for nameSet in df.columns: + # If no groups are defined, nameSet will be the metric name + hasGroup = not isinstance(nameSet, str) + Y = series[nameSet] + d = { + "group": nameSet[1:] if hasGroup else "All", + "values": [{"x": t, "y": Y[t] if t in Y else None} for t in df.index], + } + key = nameSet[0] if hasGroup else nameSet + if key in data: + data[key].append(d) + else: + data[key] = [d] + return data + + +class RoseViz(NVD3TimeSeriesViz): + + viz_type = "rose" + verbose_name = _("Time Series - Nightingale Rose Chart") + sort_series = False + is_timeseries = True + + def get_data(self, df: pd.DataFrame) -> VizData: + if df.empty: + return None + + data = super().get_data(df) + result: Dict = {} + for datum in data: # type: ignore + key = datum["key"] + for val in datum["values"]: + timestamp = val["x"].value + if not result.get(timestamp): + result[timestamp] = [] + value = 0 if math.isnan(val["y"]) else val["y"] + result[timestamp].append( + { + "key": key, + "value": value, + "name": ", ".join(key) if isinstance(key, list) else key, + "time": val["x"], + } + ) + return result + + +class PartitionViz(NVD3TimeSeriesViz): + + """ + A hierarchical data visualization with support for time series. + """ + + viz_type = "partition" + verbose_name = _("Partition Diagram") + + def query_obj(self): + query_obj = super().query_obj() + time_op = self.form_data.get("time_series_option", "not_time") + # Return time series data if the user specifies so + query_obj["is_timeseries"] = time_op != "not_time" + return query_obj + + def levels_for(self, time_op, groups, df): + """ + Compute the partition at each `level` from the dataframe. 
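+        Level 0 aggregates the whole frame; level `i` aggregates the groups formed by the
+        first `i` groupby columns, using mean for `agg_mean` and sum for everything else.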
+ """ + levels = {} + for i in range(0, len(groups) + 1): + agg_df = df.groupby(groups[:i]) if i else df + levels[i] = ( + agg_df.mean() + if time_op == "agg_mean" + else agg_df.sum(numeric_only=True) + ) + return levels + + def levels_for_diff(self, time_op, groups, df): + # Obtain a unique list of the time grains + times = list(set(df[DTTM_ALIAS])) + times.sort() + until = times[len(times) - 1] + since = times[0] + # Function describing how to calculate the difference + func = { + "point_diff": [pd.Series.sub, lambda a, b, fill_value: a - b], + "point_factor": [pd.Series.div, lambda a, b, fill_value: a / float(b)], + "point_percent": [ + lambda a, b, fill_value=0: a.div(b, fill_value=fill_value) - 1, + lambda a, b, fill_value: a / float(b) - 1, + ], + }[time_op] + agg_df = df.groupby(DTTM_ALIAS).sum() + levels = { + 0: pd.Series( + { + m: func[1](agg_df[m][until], agg_df[m][since], 0) + for m in agg_df.columns + } + ) + } + for i in range(1, len(groups) + 1): + agg_df = df.groupby([DTTM_ALIAS] + groups[:i]).sum() + levels[i] = pd.DataFrame( + { + m: func[0](agg_df[m][until], agg_df[m][since], fill_value=0) + for m in agg_df.columns + } + ) + return levels + + def levels_for_time(self, groups, df): + procs = {} + for i in range(0, len(groups) + 1): + self.form_data["groupby"] = groups[:i] + df_drop = df.drop(groups[i:], 1) + procs[i] = self.process_data(df_drop, aggregate=True) + self.form_data["groupby"] = groups + return procs + + def nest_values(self, levels, level=0, metric=None, dims=()): + """ + Nest values at each level on the back-end with + access and setting, instead of summing from the bottom. + """ + if not level: + return [ + { + "name": m, + "val": levels[0][m], + "children": self.nest_values(levels, 1, m), + } + for m in levels[0].index + ] + if level == 1: + return [ + { + "name": i, + "val": levels[1][metric][i], + "children": self.nest_values(levels, 2, metric, (i,)), + } + for i in levels[1][metric].index + ] + if level >= len(levels): + return [] + return [ + { + "name": i, + "val": levels[level][metric][dims][i], + "children": self.nest_values(levels, level + 1, metric, dims + (i,)), + } + for i in levels[level][metric][dims].index + ] + + def nest_procs(self, procs, level=-1, dims=(), time=None): + if level == -1: + return [ + {"name": m, "children": self.nest_procs(procs, 0, (m,))} + for m in procs[0].columns + ] + if not level: + return [ + { + "name": t, + "val": procs[0][dims[0]][t], + "children": self.nest_procs(procs, 1, dims, t), + } + for t in procs[0].index + ] + if level >= len(procs): + return [] + return [ + { + "name": i, + "val": procs[level][dims][i][time], + "children": self.nest_procs(procs, level + 1, dims + (i,), time), + } + for i in procs[level][dims].columns + ] + + def get_data(self, df: pd.DataFrame) -> VizData: + fd = self.form_data + groups = fd.get("groupby", []) + time_op = fd.get("time_series_option", "not_time") + if not len(groups): + raise ValueError("Please choose at least one groupby") + if time_op == "not_time": + levels = self.levels_for("agg_sum", groups, df) + elif time_op in ["agg_sum", "agg_mean"]: + levels = self.levels_for(time_op, groups, df) + elif time_op in ["point_diff", "point_factor", "point_percent"]: + levels = self.levels_for_diff(time_op, groups, df) + elif time_op == "adv_anal": + procs = self.levels_for_time(groups, df) + return self.nest_procs(procs) + else: + levels = self.levels_for("agg_sum", [DTTM_ALIAS] + groups, df) + return self.nest_values(levels) + + +viz_types = { + o.viz_type: o + for o in 
globals().values() + if ( + inspect.isclass(o) + and issubclass(o, BaseViz) + and o.viz_type not in config["VIZ_TYPE_BLACKLIST"] + ) +} diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 1ed7057bc791..699afeaec835 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -415,11 +415,11 @@ def test_run_query_no_groupby(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, filter=[], row_limit=100, @@ -472,11 +472,11 @@ def test_run_query_with_adhoc_metric(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, filter=[], row_limit=100, @@ -519,11 +519,11 @@ def test_run_query_single_groupby(self): client.query_builder.last_query.query_dict = {"mock": 0} # client.topn is called twice ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=100, client=client, order_desc=True, @@ -543,11 +543,11 @@ def test_run_query_single_groupby(self): client = Mock() client.query_builder.last_query.query_dict = {"mock": 0} ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, order_desc=False, filter=[], @@ -568,11 +568,11 @@ def test_run_query_single_groupby(self): client = Mock() client.query_builder.last_query.query_dict = {"mock": 0} ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, order_desc=True, timeseries_limit=5, @@ -619,11 +619,11 @@ def test_run_query_multiple_groupby(self): client.query_builder.last_query.query_dict = {"mock": 0} # no groupby calls client.timeseries ds.run_query( - groupby, metrics, None, from_dttm, to_dttm, + groupby=groupby, client=client, row_limit=100, filter=[], @@ -1021,11 +1021,11 @@ def test_run_query_order_by_metrics(self): granularity = "all" # get the counts of the top 5 'dim1's, order by 'sum1' ds.run_query( - groupby, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="sum1", client=client, @@ -1042,11 +1042,11 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 'dim1's, order by 'div1' ds.run_query( - groupby, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="div1", client=client, @@ -1064,11 +1064,11 @@ def test_run_query_order_by_metrics(self): groupby = ["dim1", "dim2"] # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' ds.run_query( - groupby, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="sum1", client=client, @@ -1085,11 +1085,11 @@ def test_run_query_order_by_metrics(self): # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' ds.run_query( - groupby, metrics, granularity, from_dttm, to_dttm, + groupby=groupby, timeseries_limit=5, timeseries_limit_metric="div1", client=client, diff --git a/tests/druid_func_tests_sip38.py b/tests/druid_func_tests_sip38.py new file mode 100644 index 000000000000..058d8c1743bd --- /dev/null +++ b/tests/druid_func_tests_sip38.py @@ -0,0 +1,1157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +import json +import unittest +from unittest.mock import Mock, patch + +import tests.test_app +import superset.connectors.druid.models as models +from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric +from superset.exceptions import SupersetException + +from .base_tests import SupersetTestCase + +try: + from pydruid.utils.dimensions import ( + MapLookupExtraction, + RegexExtraction, + RegisteredLookupExtraction, + TimeFormatExtraction, + ) + import pydruid.utils.postaggregator as postaggs +except ImportError: + pass + + +def mock_metric(metric_name, is_postagg=False): + metric = Mock() + metric.metric_name = metric_name + metric.metric_type = "postagg" if is_postagg else "metric" + return metric + + +def emplace(metrics_dict, metric_name, is_postagg=False): + metrics_dict[metric_name] = mock_metric(metric_name, is_postagg) + + +# Unit tests that can be run without initializing base tests +@patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"SIP_38_VIZ_REARCHITECTURE": True}, + clear=True, +) +class DruidFuncTestCase(SupersetTestCase): + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_map(self): + filters = [{"col": "deviceName", "val": ["iPhone X"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "device", + "outputName": "deviceName", + "outputType": "STRING", + "extractionFn": { + "type": "lookup", + "dimension": "dimensionName", + "outputName": "dimensionOutputName", + "replaceMissingValueWith": "missing_value", + "retainMissingValue": False, + "lookup": { + "type": "map", + "map": { + "iPhone10,1": "iPhone 8", + "iPhone10,4": "iPhone 8", + "iPhone10,2": "iPhone 8 Plus", + "iPhone10,5": "iPhone 8 Plus", + "iPhone10,3": "iPhone X", + "iPhone10,6": "iPhone X", + }, + "isOneToOne": False, + }, + }, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="deviceName", dimension_spec_json=spec_json) + column_dict = {"deviceName": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, MapLookupExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + f_ext_fn = f.extraction_function + self.assertEqual(dim_ext_fn["lookup"]["map"], f_ext_fn._mapping) + self.assertEqual(dim_ext_fn["lookup"]["isOneToOne"], f_ext_fn._injective) + self.assertEqual( + dim_ext_fn["replaceMissingValueWith"], f_ext_fn._replace_missing_values + ) + self.assertEqual( + dim_ext_fn["retainMissingValue"], f_ext_fn._retain_missing_values + ) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_regex(self): + filters = [{"col": "buildPrefix", "val": ["22B"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "build", + "outputName": "buildPrefix", + "outputType": "STRING", + 
"extractionFn": {"type": "regex", "expr": "(^[0-9A-Za-z]{3})"}, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="buildPrefix", dimension_spec_json=spec_json) + column_dict = {"buildPrefix": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, RegexExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + f_ext_fn = f.extraction_function + self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_registered_lookup_extraction(self): + filters = [{"col": "country", "val": ["Spain"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "country_name", + "outputName": "country", + "outputType": "STRING", + "extractionFn": {"type": "registeredLookup", "lookup": "country_name"}, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="country", dimension_spec_json=spec_json) + column_dict = {"country": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, RegisteredLookupExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) + self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extraction_fn_time_format(self): + filters = [{"col": "dayOfMonth", "val": ["1", "20"], "op": "in"}] + dimension_spec = { + "type": "extraction", + "dimension": "__time", + "outputName": "dayOfMonth", + "extractionFn": { + "type": "timeFormat", + "format": "d", + "timeZone": "Asia/Kolkata", + "locale": "en", + }, + } + spec_json = json.dumps(dimension_spec) + col = DruidColumn(column_name="dayOfMonth", dimension_spec_json=spec_json) + column_dict = {"dayOfMonth": col} + f = DruidDatasource.get_filters(filters, [], column_dict) + assert isinstance(f.extraction_function, TimeFormatExtraction) + dim_ext_fn = dimension_spec["extractionFn"] + self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) + self.assertEqual(dim_ext_fn["format"], f.extraction_function._format) + self.assertEqual(dim_ext_fn["timeZone"], f.extraction_function._time_zone) + self.assertEqual(dim_ext_fn["locale"], f.extraction_function._locale) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_ignores_invalid_filter_objects(self): + filtr = {"col": "col1", "op": "=="} + filters = [filtr] + col = DruidColumn(column_name="col1") + column_dict = {"col1": col} + self.assertIsNone(DruidDatasource.get_filters(filters, [], column_dict)) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_in(self): + filtr = {"col": "A", "op": "in", "val": ["a", "b", "c"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIn("filter", res.filter) + self.assertIn("fields", res.filter["filter"]) + self.assertEqual("or", res.filter["filter"]["type"]) + self.assertEqual(3, len(res.filter["filter"]["fields"])) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_not_in(self): + filtr = 
{"col": "A", "op": "not in", "val": ["a", "b", "c"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIn("filter", res.filter) + self.assertIn("type", res.filter["filter"]) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertIn("field", res.filter["filter"]) + self.assertEqual( + 3, len(res.filter["filter"]["field"].filter["filter"]["fields"]) + ) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_equals(self): + filtr = {"col": "A", "op": "==", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + self.assertEqual("h", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_filter_not_equals(self): + filtr = {"col": "A", "op": "!=", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertEqual("h", res.filter["filter"]["field"].filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_bounds_filter(self): + filtr = {"col": "A", "op": ">=", "val": "h"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertFalse(res.filter["filter"]["lowerStrict"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + self.assertEqual("h", res.filter["filter"]["lower"]) + self.assertEqual("lexicographic", res.filter["filter"]["ordering"]) + filtr["op"] = ">" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertTrue(res.filter["filter"]["lowerStrict"]) + filtr["op"] = "<=" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertFalse(res.filter["filter"]["upperStrict"]) + self.assertEqual("h", res.filter["filter"]["upper"]) + filtr["op"] = "<" + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertTrue(res.filter["filter"]["upperStrict"]) + filtr["val"] = 1 + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual("numeric", res.filter["filter"]["ordering"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_is_null_filter(self): + filtr = {"col": "A", "op": "IS NULL"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + self.assertEqual("", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_is_not_null_filter(self): + filtr = {"col": "A", "op": "IS NOT NULL"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("not", res.filter["filter"]["type"]) + self.assertIn("field", res.filter["filter"]) + self.assertEqual( + "selector", 
res.filter["filter"]["field"].filter["filter"]["type"] + ) + self.assertEqual("", res.filter["filter"]["field"].filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_regex_filter(self): + filtr = {"col": "A", "op": "regex", "val": "[abc]"} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("regex", res.filter["filter"]["type"]) + self.assertEqual("[abc]", res.filter["filter"]["pattern"]) + self.assertEqual("A", res.filter["filter"]["dimension"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_composes_multiple_filters(self): + filtr1 = {"col": "A", "op": "!=", "val": "y"} + filtr2 = {"col": "B", "op": "in", "val": ["a", "b", "c"]} + cola = DruidColumn(column_name="A") + colb = DruidColumn(column_name="B") + column_dict = {"A": cola, "B": colb} + res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) + self.assertEqual("and", res.filter["filter"]["type"]) + self.assertEqual(2, len(res.filter["filter"]["fields"])) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_ignores_in_not_in_with_empty_value(self): + filtr1 = {"col": "A", "op": "in", "val": []} + filtr2 = {"col": "A", "op": "not in", "val": []} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) + self.assertIsNone(res) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_constructs_equals_for_in_not_in_single_value(self): + filtr = {"col": "A", "op": "in", "val": ["a"]} + cola = DruidColumn(column_name="A") + colb = DruidColumn(column_name="B") + column_dict = {"A": cola, "B": colb} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("selector", res.filter["filter"]["type"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_handles_arrays_for_string_types(self): + filtr = {"col": "A", "op": "==", "val": ["a", "b"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a", res.filter["filter"]["value"]) + + filtr = {"col": "A", "op": "==", "val": []} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIsNone(res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_handles_none_for_string_types(self): + filtr = {"col": "A", "op": "==", "val": None} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertIsNone(res) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_extracts_values_in_quotes(self): + filtr = {"col": "A", "op": "in", "val": ['"a"']} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def 
test_get_filters_keeps_trailing_spaces(self): + filtr = {"col": "A", "op": "in", "val": ["a "]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], [], column_dict) + self.assertEqual("a ", res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_filters_converts_strings_to_num(self): + filtr = {"col": "A", "op": "in", "val": ["6"]} + col = DruidColumn(column_name="A") + column_dict = {"A": col} + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual(6, res.filter["filter"]["value"]) + filtr = {"col": "A", "op": "==", "val": "6"} + res = DruidDatasource.get_filters([filtr], ["A"], column_dict) + self.assertEqual(6, res.filter["filter"]["value"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_no_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = [] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(1, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.timeseries.call_args_list[0][1] + self.assertNotIn("dimensions", called_args) + self.assertIn("post_aggregations", called_args) + # restore functions + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_with_adhoc_metric(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + all_metrics = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) + columns = [] + metrics = [ + { + "expressionType": "SIMPLE", + "column": {"type": "DOUBLE", "column_name": "col1"}, + "aggregate": "SUM", + "label": "My Adhoc Metric", + } + ] + + 
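+        # Stub out having filters and the query builder so only the shape of the generated query is asserted on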
ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(1, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.timeseries.call_args_list[0][1] + self.assertNotIn("dimensions", called_args) + self.assertIn("post_aggregations", called_args) + # restore functions + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_single_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = ["metric1"] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = ["col1"] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder.last_query.query_dict = {"mock": 0} + # client.topn is called twice + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + timeseries_limit=100, + client=client, + order_desc=True, + filter=[], + ) + self.assertEqual(2, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args_pre = client.topn.call_args_list[0][1] + self.assertNotIn("dimensions", called_args_pre) + self.assertIn("dimension", called_args_pre) + called_args = client.topn.call_args_list[1][1] + self.assertIn("dimension", called_args) + self.assertEqual("col1", called_args["dimension"]) + # not order_desc + client = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + order_desc=False, + filter=[], + row_limit=100, + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(1, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + self.assertIn("dimensions", client.groupby.call_args_list[0][1]) + self.assertEqual(["col1"], client.groupby.call_args_list[0][1]["dimensions"]) + # order_desc but timeseries and dimension spec + # calls topn with single dimension spec 'dimension' + spec = {"outputName": "hello", "dimension": "matcho"} + spec_json = json.dumps(spec) + col3 = DruidColumn(column_name="col3", dimension_spec_json=spec_json) + ds.columns.append(col3) + groupby = ["col3"] + client = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=groupby, + client=client, + order_desc=True, + 
timeseries_limit=5, + filter=[], + row_limit=100, + ) + self.assertEqual(2, len(client.topn.call_args_list)) + self.assertEqual(0, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + self.assertIn("dimension", client.topn.call_args_list[0][1]) + self.assertIn("dimension", client.topn.call_args_list[1][1]) + # uses dimension for pre query and full spec for final query + self.assertEqual("matcho", client.topn.call_args_list[0][1]["dimension"]) + self.assertEqual(spec, client.topn.call_args_list[1][1]["dimension"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_run_query_multiple_groupby(self): + client = Mock() + from_dttm = Mock() + to_dttm = Mock() + from_dttm.replace = Mock(return_value=from_dttm) + to_dttm.replace = Mock(return_value=to_dttm) + from_dttm.isoformat = Mock(return_value="from") + to_dttm.isoformat = Mock(return_value="to") + timezone = "timezone" + from_dttm.tzname = Mock(return_value=timezone) + ds = DruidDatasource(datasource_name="datasource") + metric1 = DruidMetric(metric_name="metric1") + metric2 = DruidMetric(metric_name="metric2") + ds.metrics = [metric1, metric2] + col1 = DruidColumn(column_name="col1") + col2 = DruidColumn(column_name="col2") + ds.columns = [col1, col2] + aggs = [] + post_aggs = ["some_agg"] + ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) + columns = ["col1", "col2"] + metrics = ["metric1"] + ds.get_having_filters = Mock(return_value=[]) + client.query_builder = Mock() + client.query_builder.last_query = Mock() + client.query_builder.last_query.query_dict = {"mock": 0} + # no groupby calls client.timeseries + ds.run_query( + metrics, + None, + from_dttm, + to_dttm, + groupby=columns, + client=client, + row_limit=100, + filter=[], + ) + self.assertEqual(0, len(client.topn.call_args_list)) + self.assertEqual(1, len(client.groupby.call_args_list)) + self.assertEqual(0, len(client.timeseries.call_args_list)) + # check that there is no dimensions entry + called_args = client.groupby.call_args_list[0][1] + self.assertIn("dimensions", called_args) + self.assertEqual(["col1", "col2"], called_args["dimensions"]) + + @unittest.skipUnless( + SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" + ) + def test_get_post_agg_returns_correct_agg_type(self): + get_post_agg = DruidDatasource.get_post_agg + # javascript PostAggregators + function = "function(field1, field2) { return field1 + field2; }" + conf = { + "type": "javascript", + "name": "postagg_name", + "fieldNames": ["field1", "field2"], + "function": function, + } + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["type"], "javascript") + self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"]) + self.assertEqual(postagg.post_aggregator["name"], "postagg_name") + self.assertEqual(postagg.post_aggregator["function"], function) + # Quantile + conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"} + postagg = get_post_agg(conf) + self.assertTrue(isinstance(postagg, postaggs.Quantile)) + self.assertEqual(postagg.name, "postagg_name") + self.assertEqual(postagg.post_aggregator["probability"], "0.5") + # Quantiles + conf = { + "type": "quantiles", + "name": "postagg_name", + "probabilities": "0.4,0.5,0.6", + } + postagg = get_post_agg(conf) + 
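+        # A 'quantiles' config should map to pydruid's Quantiles post-aggregator with its probabilities preserved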
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_get_post_agg_returns_correct_agg_type(self):
+        get_post_agg = DruidDatasource.get_post_agg
+        # javascript PostAggregators
+        function = "function(field1, field2) { return field1 + field2; }"
+        conf = {
+            "type": "javascript",
+            "name": "postagg_name",
+            "fieldNames": ["field1", "field2"],
+            "function": function,
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator))
+        self.assertEqual(postagg.name, "postagg_name")
+        self.assertEqual(postagg.post_aggregator["type"], "javascript")
+        self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"])
+        self.assertEqual(postagg.post_aggregator["name"], "postagg_name")
+        self.assertEqual(postagg.post_aggregator["function"], function)
+        # Quantile
+        conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"}
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Quantile))
+        self.assertEqual(postagg.name, "postagg_name")
+        self.assertEqual(postagg.post_aggregator["probability"], "0.5")
+        # Quantiles
+        conf = {
+            "type": "quantiles",
+            "name": "postagg_name",
+            "probabilities": "0.4,0.5,0.6",
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Quantiles))
+        self.assertEqual(postagg.name, "postagg_name")
+        self.assertEqual(postagg.post_aggregator["probabilities"], "0.4,0.5,0.6")
+        # FieldAccess
+        conf = {"type": "fieldAccess", "name": "field_name"}
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Field))
+        self.assertEqual(postagg.name, "field_name")
+        # constant
+        conf = {"type": "constant", "value": 1234, "name": "postagg_name"}
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Const))
+        self.assertEqual(postagg.name, "postagg_name")
+        self.assertEqual(postagg.post_aggregator["value"], 1234)
+        # hyperUniqueCardinality
+        conf = {"type": "hyperUniqueCardinality", "name": "unique_name"}
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality))
+        self.assertEqual(postagg.name, "unique_name")
+        # arithmetic
+        conf = {
+            "type": "arithmetic",
+            "fn": "+",
+            "fields": ["field1", "field2"],
+            "name": "postagg_name",
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Postaggregator))
+        self.assertEqual(postagg.name, "postagg_name")
+        self.assertEqual(postagg.post_aggregator["fn"], "+")
+        self.assertEqual(postagg.post_aggregator["fields"], ["field1", "field2"])
+        # custom post aggregator
+        conf = {"type": "custom", "name": "custom_name", "stuff": "more_stuff"}
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, models.CustomPostAggregator))
+        self.assertEqual(postagg.name, "custom_name")
+        self.assertEqual(postagg.post_aggregator["stuff"], "more_stuff")
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_find_postaggs_for_returns_postaggs_and_removes(self):
+        find_postaggs_for = DruidDatasource.find_postaggs_for
+        postagg_names = set(["pa2", "pa3", "pa4", "m1", "m2", "m3", "m4"])
+
+        metrics = {}
+        for i in range(1, 6):
+            emplace(metrics, "pa" + str(i), True)
+            emplace(metrics, "m" + str(i), False)
+        postagg_list = find_postaggs_for(postagg_names, metrics)
+        self.assertEqual(3, len(postagg_list))
+        self.assertEqual(4, len(postagg_names))
+        expected_metrics = ["m1", "m2", "m3", "m4"]
+        expected_postaggs = set(["pa2", "pa3", "pa4"])
+        for postagg in postagg_list:
+            expected_postaggs.remove(postagg.metric_name)
+        for metric in expected_metrics:
+            postagg_names.remove(metric)
+        self.assertEqual(0, len(expected_postaggs))
+        self.assertEqual(0, len(postagg_names))
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_recursive_get_fields(self):
+        conf = {
+            "type": "quantile",
+            "fieldName": "f1",
+            "field": {
+                "type": "custom",
+                "fields": [
+                    {"type": "fieldAccess", "fieldName": "f2"},
+                    {"type": "fieldAccess", "fieldName": "f3"},
+                    {
+                        "type": "quantiles",
+                        "fieldName": "f4",
+                        "field": {"type": "custom"},
+                    },
+                    {
+                        "type": "custom",
+                        "fields": [
+                            {"type": "fieldAccess", "fieldName": "f5"},
+                            {
+                                "type": "fieldAccess",
+                                "fieldName": "f2",
+                                "fields": [
+                                    {"type": "fieldAccess", "fieldName": "f3"},
+                                    {"type": "fieldIgnoreMe", "fieldName": "f6"},
+                                ],
+                            },
+                        ],
+                    },
+                ],
+            },
+        }
+        fields = DruidDatasource.recursive_get_fields(conf)
+        expected = set(["f1", "f2", "f3", "f4", "f5"])
+        self.assertEqual(5, len(fields))
+        for field in fields:
+            expected.remove(field)
+        self.assertEqual(0, len(expected))
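test_recursive_get_fields above expects every fieldName reachable through nested "field"/"fields" entries to be collected exactly once, while types such as "fieldIgnoreMe" contribute nothing. One plausible shape of such a traversal, consistent with that fixture; the helper name and the exact set of field-bearing types are assumptions, not the code in superset/connectors/druid/models.py:

# Sketch of a traversal matching the fixture above; the type whitelist is assumed.
FIELD_BEARING_TYPES = {"fieldAccess", "hyperUniqueCardinality", "quantile", "quantiles"}


def collect_field_names(conf):
    """Recursively gather unique fieldNames from a post-aggregator config."""
    names = []
    if conf.get("type") in FIELD_BEARING_TYPES:
        names.append(conf.get("fieldName", ""))
    if conf.get("field"):
        names += collect_field_names(conf["field"])
    for child in conf.get("fields", []):
        names += collect_field_names(child)
    return list(set(names))


assert sorted(
    collect_field_names(
        {"type": "quantile", "fieldName": "f1", "field": {"type": "fieldAccess", "fieldName": "f2"}}
    )
) == ["f1", "f2"]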
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_metrics_and_post_aggs_tree(self):
+        metrics = ["A", "B", "m1", "m2"]
+        metrics_dict = {}
+        for i in range(ord("A"), ord("K") + 1):
+            emplace(metrics_dict, chr(i), True)
+        for i in range(1, 10):
+            emplace(metrics_dict, "m" + str(i), False)
+
+        def depends_on(index, fields):
+            dependents = fields if isinstance(fields, list) else [fields]
+            metrics_dict[index].json_obj = {"fieldNames": dependents}
+
+        depends_on("A", ["m1", "D", "C"])
+        depends_on("B", ["B", "C", "E", "F", "m3"])
+        depends_on("C", ["H", "I"])
+        depends_on("D", ["m2", "m5", "G", "C"])
+        depends_on("E", ["H", "I", "J"])
+        depends_on("F", ["J", "m5"])
+        depends_on("G", ["m4", "m7", "m6", "A"])
+        depends_on("H", ["A", "m4", "I"])
+        depends_on("I", ["H", "K"])
+        depends_on("J", "K")
+        depends_on("K", ["m8", "m9"])
+        aggs, postaggs = DruidDatasource.metrics_and_post_aggs(metrics, metrics_dict)
+        expected_metrics = set(aggs.keys())
+        self.assertEqual(9, len(aggs))
+        for i in range(1, 10):
+            expected_metrics.remove("m" + str(i))
+        self.assertEqual(0, len(expected_metrics))
+        self.assertEqual(11, len(postaggs))
+        for i in range(ord("A"), ord("K") + 1):
+            del postaggs[chr(i)]
+        self.assertEqual(0, len(postaggs))
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_metrics_and_post_aggs(self):
+        """
+        Test generation of metrics and post-aggregations from an initial list
+        of superset metrics (which may include the results of either). This
+        primarily tests that specifying a post-aggregator metric will also
+        require the raw aggregation of the associated druid metric column.
+        """
+        metrics_dict = {
+            "unused_count": DruidMetric(
+                metric_name="unused_count",
+                verbose_name="COUNT(*)",
+                metric_type="count",
+                json=json.dumps({"type": "count", "name": "unused_count"}),
+            ),
+            "some_sum": DruidMetric(
+                metric_name="some_sum",
+                verbose_name="SUM(*)",
+                metric_type="sum",
+                json=json.dumps({"type": "sum", "name": "sum"}),
+            ),
+            "a_histogram": DruidMetric(
+                metric_name="a_histogram",
+                verbose_name="APPROXIMATE_HISTOGRAM(*)",
+                metric_type="approxHistogramFold",
+                json=json.dumps({"type": "approxHistogramFold", "name": "a_histogram"}),
+            ),
+            "aCustomMetric": DruidMetric(
+                metric_name="aCustomMetric",
+                verbose_name="MY_AWESOME_METRIC(*)",
+                metric_type="aCustomType",
+                json=json.dumps({"type": "customMetric", "name": "aCustomMetric"}),
+            ),
+            "quantile_p95": DruidMetric(
+                metric_name="quantile_p95",
+                verbose_name="P95(*)",
+                metric_type="postagg",
+                json=json.dumps(
+                    {
+                        "type": "quantile",
+                        "probability": 0.95,
+                        "name": "p95",
+                        "fieldName": "a_histogram",
+                    }
+                ),
+            ),
+            "aCustomPostAgg": DruidMetric(
+                metric_name="aCustomPostAgg",
+                verbose_name="CUSTOM_POST_AGG(*)",
+                metric_type="postagg",
+                json=json.dumps(
+                    {
+                        "type": "customPostAgg",
+                        "name": "aCustomPostAgg",
+                        "field": {"type": "fieldAccess", "fieldName": "aCustomMetric"},
+                    }
+                ),
+            ),
+        }
+
+        adhoc_metric = {
+            "expressionType": "SIMPLE",
+            "column": {"type": "DOUBLE", "column_name": "value"},
+            "aggregate": "SUM",
+            "label": "My Adhoc Metric",
+        }
+
+        metrics = ["some_sum"]
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict
+        )
+
+        assert set(saved_metrics.keys()) == {"some_sum"}
+        assert post_aggs == {}
+
+        metrics = [adhoc_metric]
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict
+        )
+
+        assert set(saved_metrics.keys()) == set([adhoc_metric["label"]])
+        assert post_aggs == {}
+
+        metrics = ["some_sum", adhoc_metric]
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict
+        )
+
+        assert set(saved_metrics.keys()) == {"some_sum", adhoc_metric["label"]}
+        assert post_aggs == {}
+
+        metrics = ["quantile_p95"]
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict
+        )
+
+        result_postaggs = set(["quantile_p95"])
+        assert set(saved_metrics.keys()) == {"a_histogram"}
+        assert set(post_aggs.keys()) == result_postaggs
+
+        metrics = ["aCustomPostAgg"]
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict
+        )
+
+        result_postaggs = set(["aCustomPostAgg"])
+        assert set(saved_metrics.keys()) == {"aCustomMetric"}
+        assert set(post_aggs.keys()) == result_postaggs
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_druid_type_from_adhoc_metric(self):
+
+        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
+            {
+                "column": {"type": "DOUBLE", "column_name": "value"},
+                "aggregate": "SUM",
+                "label": "My Adhoc Metric",
+            }
+        )
+        assert druid_type == "doubleSum"
+
+        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
+            {
+                "column": {"type": "LONG", "column_name": "value"},
+                "aggregate": "MAX",
+                "label": "My Adhoc Metric",
+            }
+        )
+        assert druid_type == "longMax"
+
+        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
+            {
+                "column": {"type": "VARCHAR(255)", "column_name": "value"},
+                "aggregate": "COUNT",
+                "label": "My Adhoc Metric",
+            }
+        )
+        assert druid_type == "count"
+
+        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
+            {
+                "column": {"type": "VARCHAR(255)", "column_name": "value"},
+                "aggregate": "COUNT_DISTINCT",
+                "label": "My Adhoc Metric",
+            }
+        )
+        assert druid_type == "cardinality"
+
+        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
+            {
+                "column": {"type": "hyperUnique", "column_name": "value"},
+                "aggregate": "COUNT_DISTINCT",
+                "label": "My Adhoc Metric",
+            }
+        )
+        assert druid_type == "hyperUnique"
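The assertions in test_druid_type_from_adhoc_metric above fix the mapping from an ad-hoc metric's column type and aggregate to a Druid aggregator type. A hedged sketch of that mapping; the helper name and the lowercase-plus-capitalize composition rule are assumptions inferred from the asserted pairs, not the actual druid_type_from_adhoc_metric implementation:

# Illustrative mapping only, pinned by the asserted (column type, aggregate) pairs.
def adhoc_metric_to_druid_type(column_type, aggregate):
    if aggregate == "COUNT":
        return "count"
    if aggregate == "COUNT_DISTINCT":
        return "hyperUnique" if column_type == "hyperUnique" else "cardinality"
    # e.g. ("DOUBLE", "SUM") -> "doubleSum", ("LONG", "MAX") -> "longMax"
    return column_type.lower() + aggregate.capitalize()


assert adhoc_metric_to_druid_type("DOUBLE", "SUM") == "doubleSum"
assert adhoc_metric_to_druid_type("LONG", "MAX") == "longMax"
assert adhoc_metric_to_druid_type("VARCHAR(255)", "COUNT") == "count"
assert adhoc_metric_to_druid_type("VARCHAR(255)", "COUNT_DISTINCT") == "cardinality"
assert adhoc_metric_to_druid_type("hyperUnique", "COUNT_DISTINCT") == "hyperUnique"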
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_run_query_order_by_metrics(self):
+        client = Mock()
+        client.query_builder.last_query.query_dict = {"mock": 0}
+        from_dttm = Mock()
+        to_dttm = Mock()
+        ds = DruidDatasource(datasource_name="datasource")
+        ds.get_having_filters = Mock(return_value=[])
+        dim1 = DruidColumn(column_name="dim1")
+        dim2 = DruidColumn(column_name="dim2")
+        metrics_dict = {
+            "count1": DruidMetric(
+                metric_name="count1",
+                metric_type="count",
+                json=json.dumps({"type": "count", "name": "count1"}),
+            ),
+            "sum1": DruidMetric(
+                metric_name="sum1",
+                metric_type="doubleSum",
+                json=json.dumps({"type": "doubleSum", "name": "sum1"}),
+            ),
+            "sum2": DruidMetric(
+                metric_name="sum2",
+                metric_type="doubleSum",
+                json=json.dumps({"type": "doubleSum", "name": "sum2"}),
+            ),
+            "div1": DruidMetric(
+                metric_name="div1",
+                metric_type="postagg",
+                json=json.dumps(
+                    {
+                        "fn": "/",
+                        "type": "arithmetic",
+                        "name": "div1",
+                        "fields": [
+                            {"fieldName": "sum1", "type": "fieldAccess"},
+                            {"fieldName": "sum2", "type": "fieldAccess"},
+                        ],
+                    }
+                ),
+            ),
+        }
+        ds.columns = [dim1, dim2]
+        ds.metrics = list(metrics_dict.values())
+
+        columns = ["dim1"]
+        metrics = ["count1"]
+        granularity = "all"
+        # get the counts of the top 5 'dim1's, order by 'sum1'
+        ds.run_query(
+            metrics,
+            granularity,
+            from_dttm,
+            to_dttm,
+            groupby=columns,
+            timeseries_limit=5,
+            timeseries_limit_metric="sum1",
+            client=client,
+            order_desc=True,
+            filter=[],
+        )
+        qry_obj = client.topn.call_args_list[0][1]
+        self.assertEqual("dim1", qry_obj["dimension"])
+        self.assertEqual("sum1", qry_obj["metric"])
+        aggregations = qry_obj["aggregations"]
+        post_aggregations = qry_obj["post_aggregations"]
+        self.assertEqual({"count1", "sum1"}, set(aggregations.keys()))
+        self.assertEqual(set(), set(post_aggregations.keys()))
+
+        # get the counts of the top 5 'dim1's, order by 'div1'
+        ds.run_query(
+            metrics,
+            granularity,
+            from_dttm,
+            to_dttm,
+            groupby=columns,
+            timeseries_limit=5,
+            timeseries_limit_metric="div1",
+            client=client,
+            order_desc=True,
+            filter=[],
+        )
+        qry_obj = client.topn.call_args_list[1][1]
+        self.assertEqual("dim1", qry_obj["dimension"])
+        self.assertEqual("div1", qry_obj["metric"])
+        aggregations = qry_obj["aggregations"]
+        post_aggregations = qry_obj["post_aggregations"]
+        self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys()))
+        self.assertEqual({"div1"}, set(post_aggregations.keys()))
+
+        columns = ["dim1", "dim2"]
+        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1'
+        ds.run_query(
+            metrics,
+            granularity,
+            from_dttm,
+            to_dttm,
+            groupby=columns,
+            timeseries_limit=5,
+            timeseries_limit_metric="sum1",
+            client=client,
+            order_desc=True,
+            filter=[],
+        )
+        qry_obj = client.groupby.call_args_list[0][1]
+        self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"]))
+        self.assertEqual("sum1", qry_obj["limit_spec"]["columns"][0]["dimension"])
+        aggregations = qry_obj["aggregations"]
+        post_aggregations = qry_obj["post_aggregations"]
+        self.assertEqual({"count1", "sum1"}, set(aggregations.keys()))
+        self.assertEqual(set(), set(post_aggregations.keys()))
+
+        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1'
+        ds.run_query(
+            metrics,
+            granularity,
+            from_dttm,
+            to_dttm,
+            groupby=columns,
+            timeseries_limit=5,
+            timeseries_limit_metric="div1",
+            client=client,
+            order_desc=True,
+            filter=[],
+        )
+        qry_obj = client.groupby.call_args_list[1][1]
+        self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"]))
+        self.assertEqual("div1", qry_obj["limit_spec"]["columns"][0]["dimension"])
+        aggregations = qry_obj["aggregations"]
+        post_aggregations = qry_obj["post_aggregations"]
+        self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys()))
+        self.assertEqual({"div1"}, set(post_aggregations.keys()))
+
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_get_aggregations(self):
+        ds = DruidDatasource(datasource_name="datasource")
+        metrics_dict = {
+            "sum1": DruidMetric(
+                metric_name="sum1",
+                metric_type="doubleSum",
+                json=json.dumps({"type": "doubleSum", "name": "sum1"}),
+            ),
+            "sum2": DruidMetric(
+                metric_name="sum2",
+                metric_type="doubleSum",
+                json=json.dumps({"type": "doubleSum", "name": "sum2"}),
+            ),
+            "div1": DruidMetric(
+                metric_name="div1",
+                metric_type="postagg",
+                json=json.dumps(
+                    {
+                        "fn": "/",
+                        "type": "arithmetic",
+                        "name": "div1",
+                        "fields": [
+                            {"fieldName": "sum1", "type": "fieldAccess"},
+                            {"fieldName": "sum2", "type": "fieldAccess"},
+                        ],
+                    }
+                ),
+            ),
+        }
+        metric_names = ["sum1", "sum2"]
+        aggs = ds.get_aggregations(metrics_dict, metric_names)
+        expected_agg = {name: metrics_dict[name].json_obj for name in metric_names}
+        self.assertEqual(expected_agg, aggs)
+
+        metric_names = ["sum1", "col1"]
+        self.assertRaises(
+            SupersetException, ds.get_aggregations, metrics_dict, metric_names
+        )
+
+        metric_names = ["sum1", "div1"]
+        self.assertRaises(
+            SupersetException, ds.get_aggregations, metrics_dict, metric_names
+        )
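The test_get_aggregations assertions above pin down a simple contract: known aggregator metrics map to their JSON definition, while unknown names or post-aggregations raise SupersetException. A standalone sketch of that contract using plain dicts and a stand-in exception; names here are hypothetical, and the real method works on DruidMetric objects rather than dicts:

# Illustrative contract only, not DruidDatasource.get_aggregations.
class InvalidMetric(Exception):
    """Stand-in for SupersetException in this sketch."""


def build_aggregations(metrics_dict, metric_names):
    aggregations = {}
    for name in metric_names:
        metric = metrics_dict.get(name)
        if metric is None or metric["metric_type"] == "postagg":
            raise InvalidMetric(f"Metric '{name}' is not a valid aggregation")
        aggregations[name] = metric["json_obj"]
    return aggregations


metrics = {
    "sum1": {"metric_type": "doubleSum", "json_obj": {"type": "doubleSum", "name": "sum1"}},
    "div1": {"metric_type": "postagg", "json_obj": {"type": "arithmetic", "name": "div1"}},
}
assert build_aggregations(metrics, ["sum1"]) == {"sum1": {"type": "doubleSum", "name": "sum1"}}
for bad_names in (["sum1", "col1"], ["sum1", "div1"]):
    try:
        build_aggregations(metrics, bad_names)
    except InvalidMetric:
        pass
    else:
        raise AssertionError("expected InvalidMetric")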