From 794eff9471a66fe151276fad3d0f5841ee120f33 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Wed, 4 May 2022 12:48:48 -0700 Subject: [PATCH] chore: remove druid datasource from the config (#19770) * remove druid datasource from the config * remove config related references to DruidDatasource * Update __init__.py * Update __init__.py * Update manager.py * remove config related references to DruidDatasource * raise if instance type is not valid --- UPDATING.md | 1 + .../visualizations/FilterBox/controlPanel.jsx | 59 - superset/cli/update.py | 34 - superset/config.py | 22 +- superset/connectors/druid/__init__.py | 16 - superset/connectors/druid/models.py | 1723 ----------------- superset/connectors/druid/views.py | 445 ----- superset/dashboards/commands/importers/v0.py | 19 +- superset/datasets/commands/importers/v0.py | 64 +- superset/initialization/__init__.py | 69 +- superset/models/dashboard.py | 5 - superset/security/manager.py | 6 +- superset/utils/dict_import_export.py | 19 - superset/views/base.py | 6 +- superset/views/core.py | 55 - tests/integration_tests/access_tests.py | 87 - tests/integration_tests/base_tests.py | 31 +- tests/integration_tests/core_tests.py | 1 - .../dict_import_export_tests.py | 154 +- tests/integration_tests/druid_func_tests.py | 1152 ----------- .../druid_func_tests_sip38.py | 1157 ----------- tests/integration_tests/druid_tests.py | 668 ------- .../integration_tests/import_export_tests.py | 105 +- tests/integration_tests/security_tests.py | 111 -- 24 files changed, 21 insertions(+), 5988 deletions(-) delete mode 100644 superset/connectors/druid/__init__.py delete mode 100644 superset/connectors/druid/models.py delete mode 100644 superset/connectors/druid/views.py delete mode 100644 tests/integration_tests/druid_func_tests.py delete mode 100644 tests/integration_tests/druid_func_tests_sip38.py delete mode 100644 tests/integration_tests/druid_tests.py diff --git a/UPDATING.md b/UPDATING.md index e6cce388670d..94f2103a3cce 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -30,6 +30,7 @@ assists people when migrating to a new version. ### Breaking Changes +- [19770](https://github.com/apache/superset/pull/19770): As per SIPs 11 and 68, the native NoSQL Druid connector is deprecated and has been removed. Druid is still supported through SQLAlchemy via pydruid. The config keys `DRUID_IS_ACTIVE` and `DRUID_METADATA_LINKS_ENABLED` have also been removed. - [19274](https://github.com/apache/superset/pull/19274): The `PUBLIC_ROLE_LIKE_GAMMA` config key has been removed, set `PUBLIC_ROLE_LIKE = "Gamma"` to have the same functionality. - [19273](https://github.com/apache/superset/pull/19273): The `SUPERSET_CELERY_WORKERS` and `SUPERSET_WORKERS` config keys has been removed. Configure Celery directly using `CELERY_CONFIG` on Superset. - [19262](https://github.com/apache/superset/pull/19262): Per [SIP-11](https://github.com/apache/superset/issues/6032) and [SIP-68](https://github.com/apache/superset/issues/14909) the native NoSQL Druid connector is deprecated and will no longer be supported. Druid SQL is still [supported](https://superset.apache.org/docs/databases/druid). diff --git a/superset-frontend/src/visualizations/FilterBox/controlPanel.jsx b/superset-frontend/src/visualizations/FilterBox/controlPanel.jsx index 34a814a43efe..29906843da51 100644 --- a/superset-frontend/src/visualizations/FilterBox/controlPanel.jsx +++ b/superset-frontend/src/visualizations/FilterBox/controlPanel.jsx @@ -20,36 +20,6 @@ import React from 'react'; import { t } from '@superset-ui/core'; import { sections } from '@superset-ui/chart-controls'; -const appContainer = document.getElementById('app'); -const bootstrapData = JSON.parse(appContainer.getAttribute('data-bootstrap')); -const druidIsActive = !!bootstrapData?.common?.conf?.DRUID_IS_ACTIVE; -const druidSection = druidIsActive - ? [ - [ - { - name: 'show_druid_time_granularity', - config: { - type: 'CheckboxControl', - label: t('Show Druid granularity dropdown'), - default: false, - description: t('Check to include Druid granularity dropdown'), - }, - }, - ], - [ - { - name: 'show_druid_time_origin', - config: { - type: 'CheckboxControl', - label: t('Show Druid time origin'), - default: false, - description: t('Check to include time origin dropdown'), - }, - }, - ], - ] - : []; - export default { controlPanelSections: [ sections.legacyTimeseriesTime, @@ -96,35 +66,6 @@ export default { }, }, ], - [ - { - name: 'show_sqla_time_granularity', - config: { - type: 'CheckboxControl', - label: druidIsActive - ? t('Show SQL time grain dropdown') - : t('Show time grain dropdown'), - default: false, - description: druidIsActive - ? t('Check to include SQL time grain dropdown') - : t('Check to include time grain dropdown'), - }, - }, - ], - [ - { - name: 'show_sqla_time_column', - config: { - type: 'CheckboxControl', - label: druidIsActive - ? t('Show SQL time column') - : t('Show time column'), - default: false, - description: t('Check to include time column dropdown'), - }, - }, - ], - ...druidSection, ['adhoc_filters'], ], }, diff --git a/superset/cli/update.py b/superset/cli/update.py index f7b4edd1269b..d31460c5e6c8 100755 --- a/superset/cli/update.py +++ b/superset/cli/update.py @@ -18,7 +18,6 @@ import logging import os import sys -from datetime import datetime from typing import Optional import click @@ -53,39 +52,6 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None: database_utils.get_or_create_db(database_name, uri, not skip_create) -@click.command() -@with_appcontext -@click.option( - "--datasource", - "-d", - help="Specify which datasource name to load, if " - "omitted, all datasources will be refreshed", -) -@click.option( - "--merge", - "-m", - is_flag=True, - default=False, - help="Specify using 'merge' property during operation. " "Default value is False.", -) -def refresh_druid(datasource: str, merge: bool) -> None: - """Refresh druid datasources""" - # pylint: disable=import-outside-toplevel - from superset.connectors.druid.models import DruidCluster - - session = db.session() - - for cluster in session.query(DruidCluster).all(): - try: - cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge) - except Exception as ex: # pylint: disable=broad-except - print("Error while processing cluster '{}'\n{}".format(cluster, str(ex))) - logger.exception(ex) - cluster.metadata_last_refreshed = datetime.now() - print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]") - session.commit() - - @click.command() @with_appcontext def update_datasources_cache() -> None: diff --git a/superset/config.py b/superset/config.py index 793f2ff00b83..3ca777be62eb 100644 --- a/superset/config.py +++ b/superset/config.py @@ -257,16 +257,6 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: DRUID_TZ = tz.tzutc() DRUID_ANALYSIS_TYPES = ["cardinality"] -# Legacy Druid NoSQL (native) connector -# Druid supports a SQL interface in its newer versions. -# Setting this flag to True enables the deprecated, API-based Druid -# connector. This feature may be removed at a future date. -DRUID_IS_ACTIVE = False - -# If Druid is active whether to include the links to scan/refresh Druid datasources. -# This should be disabled if you are trying to wean yourself off of the Druid NoSQL -# connector. -DRUID_METADATA_LINKS_ENABLED = True # ---------------------------------------------------- # AUTHENTICATION CONFIG @@ -645,19 +635,12 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: VIZ_TYPE_DENYLIST: List[str] = [] -# --------------------------------------------------- -# List of data sources not to be refreshed in druid cluster -# --------------------------------------------------- - -DRUID_DATA_SOURCE_DENYLIST: List[str] = [] - # -------------------------------------------------- # Modules, datasources and middleware to be registered # -------------------------------------------------- DEFAULT_MODULE_DS_MAP = OrderedDict( [ ("superset.connectors.sqla.models", ["SqlaTable"]), - ("superset.connectors.druid.models", ["DruidDatasource"]), ] ) ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} @@ -983,8 +966,11 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # Provide a callable that receives a tracking_url and returns another # URL. This is used to translate internal Hadoop job tracker URL # into a proxied one + + TRACKING_URL_TRANSFORMER = lambda x: x + # Interval between consecutive polls when using Hive Engine HIVE_POLL_INTERVAL = int(timedelta(seconds=5).total_seconds()) @@ -1202,8 +1188,10 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument # to allow mutating the object with this callback. # This can be used to set any properties of the object based on naming # conventions and such. You can find examples in the tests. + SQLA_TABLE_MUTATOR = lambda table: table + # Global async query config options. # Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled. GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = { diff --git a/superset/connectors/druid/__init__.py b/superset/connectors/druid/__init__.py deleted file mode 100644 index 13a83393a912..000000000000 --- a/superset/connectors/druid/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py deleted file mode 100644 index f64bb1882da5..000000000000 --- a/superset/connectors/druid/models.py +++ /dev/null @@ -1,1723 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: skip-file -import json -import logging -import os -import pandas as pd -import re -import sqlalchemy as sa -from collections import OrderedDict -from copy import deepcopy -from datetime import datetime, timedelta -from dateutil.parser import parse as dparse -from distutils.version import LooseVersion -from flask import escape, Markup -from flask_appbuilder import Model -from flask_appbuilder.models.decorators import renders -from flask_appbuilder.security.sqla.models import User -from flask_babel import lazy_gettext as _ -from multiprocessing.pool import ThreadPool -from sqlalchemy import ( - Boolean, - Column, - DateTime, - ForeignKey, - Integer, - String, - Table, - Text, - UniqueConstraint, - update, -) -from sqlalchemy.engine.base import Connection -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import backref, relationship, Session -from sqlalchemy.orm.mapper import Mapper -from sqlalchemy.sql import expression -from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union - -from superset import conf, db, security_manager -from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric -from superset.constants import NULL_STRING -from superset.exceptions import SupersetException -from superset.extensions import encrypted_field_factory -from superset.models.core import Database -from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult -from superset.superset_typing import ( - AdhocMetric, - AdhocMetricColumn, - FilterValues, - Granularity, - Metric, - QueryObjectDict, -) -from superset.utils import core as utils -from superset.utils.date_parser import parse_human_datetime, parse_human_timedelta -from superset.utils.memoized import memoized - -try: - import requests - from pydruid.client import PyDruid - from pydruid.utils.aggregators import count - from pydruid.utils.dimensions import ( - MapLookupExtraction, - RegexExtraction, - RegisteredLookupExtraction, - TimeFormatExtraction, - ) - from pydruid.utils.filters import Bound, Dimension, Filter - from pydruid.utils.having import Aggregation, Having - from pydruid.utils.postaggregator import ( - Const, - Field, - HyperUniqueCardinality, - Postaggregator, - Quantile, - Quantiles, - ) -except ImportError: - pass - -try: - from superset.utils.core import ( - DimSelector, - DTTM_ALIAS, - FilterOperator, - flasher, - get_metric_name, - ) -except ImportError: - pass - -DRUID_TZ = conf.get("DRUID_TZ") -POST_AGG_TYPE = "postagg" -metadata = Model.metadata # pylint: disable=no-member -logger = logging.getLogger(__name__) - -try: - # Postaggregator might not have been imported. - class JavascriptPostAggregator(Postaggregator): - def __init__(self, name: str, field_names: List[str], function: str) -> None: - self.post_aggregator = { - "type": "javascript", - "fieldNames": field_names, - "name": name, - "function": function, - } - self.name = name - - class CustomPostAggregator(Postaggregator): - """A way to allow users to specify completely custom PostAggregators""" - - def __init__(self, name: str, post_aggregator: Dict[str, Any]) -> None: - self.name = name - self.post_aggregator = post_aggregator - -except NameError: - pass - -# Function wrapper because bound methods cannot -# be passed to processes -def _fetch_metadata_for(datasource: "DruidDatasource") -> Optional[Dict[str, Any]]: - return datasource.latest_metadata() - - -class DruidCluster(Model, AuditMixinNullable, ImportExportMixin): - - """ORM object referencing the Druid clusters""" - - __tablename__ = "clusters" - type = "druid" - - id = Column(Integer, primary_key=True) - verbose_name = Column(String(250), unique=True) - # short unique name, used in permissions - cluster_name = Column(String(250), unique=True, nullable=False) - broker_host = Column(String(255)) - broker_port = Column(Integer, default=8082) - broker_endpoint = Column(String(255), default="druid/v2") - metadata_last_refreshed = Column(DateTime) - cache_timeout = Column(Integer) - broker_user = Column(String(255)) - broker_pass = Column(encrypted_field_factory.create(String(255))) - - export_fields = [ - "cluster_name", - "broker_host", - "broker_port", - "broker_endpoint", - "cache_timeout", - "broker_user", - ] - update_from_object_fields = export_fields - export_children = ["datasources"] - - def __repr__(self) -> str: - return self.verbose_name if self.verbose_name else self.cluster_name - - def __html__(self) -> str: - return self.__repr__() - - @property - def data(self) -> Dict[str, Any]: - return {"id": self.id, "name": self.cluster_name, "backend": "druid"} - - @staticmethod - def get_base_url(host: str, port: int) -> str: - if not re.match("http(s)?://", host): - host = "http://" + host - - url = "{0}:{1}".format(host, port) if port else host - return url - - def get_base_broker_url(self) -> str: - base_url = self.get_base_url(self.broker_host, self.broker_port) - return f"{base_url}/{self.broker_endpoint}" - - def get_pydruid_client(self) -> "PyDruid": - cli = PyDruid( - self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint - ) - if self.broker_user and self.broker_pass: - cli.set_basic_auth_credentials(self.broker_user, self.broker_pass) - return cli - - def get_datasources(self) -> List[str]: - endpoint = self.get_base_broker_url() + "/datasources" - auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass) - return json.loads(requests.get(endpoint, auth=auth).text) - - def get_druid_version(self) -> str: - endpoint = self.get_base_url(self.broker_host, self.broker_port) + "/status" - auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass) - return json.loads(requests.get(endpoint, auth=auth).text)["version"] - - @property # type: ignore - @memoized - def druid_version(self) -> str: - return self.get_druid_version() - - def refresh_datasources( - self, - datasource_name: Optional[str] = None, - merge_flag: bool = True, - refresh_all: bool = True, - ) -> None: - """Refresh metadata of all datasources in the cluster - If ``datasource_name`` is specified, only that datasource is updated - """ - ds_list = self.get_datasources() - denylist = conf.get("DRUID_DATA_SOURCE_DENYLIST", []) - ds_refresh: List[str] = [] - if not datasource_name: - ds_refresh = list(filter(lambda ds: ds not in denylist, ds_list)) - elif datasource_name not in denylist and datasource_name in ds_list: - ds_refresh.append(datasource_name) - else: - return - self.refresh(ds_refresh, merge_flag, refresh_all) - - def refresh( - self, datasource_names: List[str], merge_flag: bool, refresh_all: bool - ) -> None: - """ - Fetches metadata for the specified datasources and - merges to the Superset database - """ - session = db.session - ds_list = ( - session.query(DruidDatasource) - .filter(DruidDatasource.cluster_id == self.id) - .filter(DruidDatasource.datasource_name.in_(datasource_names)) - ) - ds_map = {ds.name: ds for ds in ds_list} - for ds_name in datasource_names: - datasource = ds_map.get(ds_name, None) - if not datasource: - datasource = DruidDatasource(datasource_name=ds_name) - with session.no_autoflush: - session.add(datasource) - flasher(_("Adding new datasource [{}]").format(ds_name), "success") - ds_map[ds_name] = datasource - elif refresh_all: - flasher(_("Refreshing datasource [{}]").format(ds_name), "info") - else: - del ds_map[ds_name] - continue - datasource.cluster = self - datasource.merge_flag = merge_flag - session.flush() - - # Prepare multithreaded executation - pool = ThreadPool() - ds_refresh = list(ds_map.values()) - metadata = pool.map(_fetch_metadata_for, ds_refresh) - pool.close() - pool.join() - - for i in range(0, len(ds_refresh)): - datasource = ds_refresh[i] - cols = metadata[i] - if cols: - col_objs_list = ( - session.query(DruidColumn) - .filter(DruidColumn.datasource_id == datasource.id) - .filter(DruidColumn.column_name.in_(cols.keys())) - ) - col_objs = {col.column_name: col for col in col_objs_list} - for col in cols: - if col == "__time": # skip the time column - continue - col_obj = col_objs.get(col) - if not col_obj: - col_obj = DruidColumn( - datasource_id=datasource.id, column_name=col - ) - with session.no_autoflush: - session.add(col_obj) - col_obj.type = cols[col]["type"] - col_obj.datasource = datasource - if col_obj.type == "STRING": - col_obj.groupby = True - col_obj.filterable = True - datasource.refresh_metrics() - session.commit() - - @hybrid_property - def perm(self) -> str: - return f"[{self.cluster_name}].(id:{self.id})" - - @perm.expression # type: ignore - def perm(cls) -> str: # pylint: disable=no-self-argument - return "[" + cls.cluster_name + "].(id:" + expression.cast(cls.id, String) + ")" - - def get_perm(self) -> str: - return self.perm # type: ignore - - @property - def name(self) -> str: - return self.verbose_name or self.cluster_name - - @property - def unique_name(self) -> str: - return self.verbose_name or self.cluster_name - - -sa.event.listen(DruidCluster, "after_insert", security_manager.set_perm) -sa.event.listen(DruidCluster, "after_update", security_manager.set_perm) - - -class DruidColumn(Model, BaseColumn): - """ORM model for storing Druid datasource column metadata""" - - __tablename__ = "columns" - __table_args__ = (UniqueConstraint("column_name", "datasource_id"),) - - datasource_id = Column(Integer, ForeignKey("datasources.id")) - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - "DruidDatasource", - backref=backref("columns", cascade="all, delete-orphan"), - enable_typechecks=False, - ) - dimension_spec_json = Column(Text) - - export_fields = [ - "datasource_id", - "column_name", - "is_active", - "type", - "groupby", - "filterable", - "description", - "dimension_spec_json", - "verbose_name", - ] - update_from_object_fields = export_fields - export_parent = "datasource" - - def __repr__(self) -> str: - return self.column_name or str(self.id) - - @property - def expression(self) -> str: - return self.dimension_spec_json - - @property - def dimension_spec(self) -> Optional[Dict[str, Any]]: - if self.dimension_spec_json: - return json.loads(self.dimension_spec_json) - return None - - def get_metrics(self) -> Dict[str, "DruidMetric"]: - metrics = { - "count": DruidMetric( - metric_name="count", - verbose_name="COUNT(*)", - metric_type="count", - json=json.dumps({"type": "count", "name": "count"}), - ) - } - return metrics - - def refresh_metrics(self) -> None: - """Refresh metrics based on the column metadata""" - metrics = self.get_metrics() - dbmetrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == self.datasource_id) - .filter(DruidMetric.metric_name.in_(metrics.keys())) - ) - dbmetrics = {metric.metric_name: metric for metric in dbmetrics} - for metric in metrics.values(): - dbmetric = dbmetrics.get(metric.metric_name) - if dbmetric: - for attr in ["json", "metric_type"]: - setattr(dbmetric, attr, getattr(metric, attr)) - else: - with db.session.no_autoflush: - metric.datasource_id = self.datasource_id - db.session.add(metric) - - -class DruidMetric(Model, BaseMetric): - - """ORM object referencing Druid metrics for a datasource""" - - __tablename__ = "metrics" - __table_args__ = (UniqueConstraint("metric_name", "datasource_id"),) - datasource_id = Column(Integer, ForeignKey("datasources.id")) - - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - "DruidDatasource", - backref=backref("metrics", cascade="all, delete-orphan"), - enable_typechecks=False, - ) - json = Column(Text, nullable=False) - - export_fields = [ - "metric_name", - "verbose_name", - "metric_type", - "datasource_id", - "json", - "description", - "d3format", - "warning_text", - ] - update_from_object_fields = export_fields - export_parent = "datasource" - - @property - def expression(self) -> Column: - return self.json - - @property - def json_obj(self) -> Dict[str, Any]: - try: - obj = json.loads(self.json) - except Exception: - obj = {} - return obj - - @property - def perm(self) -> Optional[str]: - return ( - ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format( - obj=self, parent_name=self.datasource.full_name - ) - if self.datasource - else None - ) - - def get_perm(self) -> Optional[str]: - return self.perm - - -druiddatasource_user = Table( - "druiddatasource_user", - metadata, - Column("id", Integer, primary_key=True), - Column("user_id", Integer, ForeignKey("ab_user.id")), - Column("datasource_id", Integer, ForeignKey("datasources.id")), -) - - -class DruidDatasource(Model, BaseDatasource): - - """ORM object referencing Druid datasources (tables)""" - - __tablename__ = "datasources" - __table_args__ = (UniqueConstraint("datasource_name", "cluster_id"),) - - type = "druid" - query_language = "json" - cluster_class = DruidCluster - columns: List[DruidColumn] = [] - metrics: List[DruidMetric] = [] - metric_class = DruidMetric - column_class = DruidColumn - owner_class = security_manager.user_model - - baselink = "druiddatasourcemodelview" - - # Columns - datasource_name = Column(String(255), nullable=False) - is_hidden = Column(Boolean, default=False) - filter_select_enabled = Column(Boolean, default=True) # override default - fetch_values_from = Column(String(100)) - cluster_id = Column(Integer, ForeignKey("clusters.id"), nullable=False) - cluster = relationship( - "DruidCluster", backref="datasources", foreign_keys=[cluster_id] - ) - owners = relationship( - owner_class, secondary=druiddatasource_user, backref="druiddatasources" - ) - - export_fields = [ - "datasource_name", - "is_hidden", - "description", - "default_endpoint", - "cluster_id", - "offset", - "cache_timeout", - "params", - "filter_select_enabled", - ] - update_from_object_fields = export_fields - - export_parent = "cluster" - export_children = ["columns", "metrics"] - - @property - def cluster_name(self) -> str: - cluster = ( - self.cluster - or db.session.query(DruidCluster).filter_by(id=self.cluster_id).one() - ) - return cluster.cluster_name - - @property - def database(self) -> DruidCluster: - return self.cluster - - @property - def connection(self) -> str: - return str(self.database) - - @property - def num_cols(self) -> List[str]: - return [c.column_name for c in self.columns if c.is_numeric] - - @property - def name(self) -> str: - return self.datasource_name - - @property - def datasource_type(self) -> str: - return self.type - - @property - def schema(self) -> Optional[str]: - ds_name = self.datasource_name or "" - name_pieces = ds_name.split(".") - if len(name_pieces) > 1: - return name_pieces[0] - else: - return None - - def get_schema_perm(self) -> Optional[str]: - """Returns schema permission if present, cluster one otherwise.""" - return security_manager.get_schema_perm(self.cluster, self.schema) - - def get_perm(self) -> str: - return ("[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format( - obj=self - ) - - def update_from_object(self, obj: Dict[str, Any]) -> None: - raise NotImplementedError() - - @property - def link(self) -> Markup: - name = escape(self.datasource_name) - return Markup(f'{name}') - - @property - def full_name(self) -> str: - return utils.get_datasource_full_name(self.cluster_name, self.datasource_name) - - @property - def time_column_grains(self) -> Dict[str, List[str]]: - return { - "time_columns": [ - "all", - "5 seconds", - "30 seconds", - "1 minute", - "5 minutes", - "30 minutes", - "1 hour", - "6 hour", - "1 day", - "7 days", - "week", - "week_starting_sunday", - "week_ending_saturday", - "month", - "quarter", - "year", - ], - "time_grains": ["now"], - } - - def __repr__(self) -> str: - return self.datasource_name - - @renders("datasource_name") - def datasource_link(self) -> str: - prefix = os.environ["APP_PREFIX"] - url = f"{prefix}/superset/explore/{self.type}/{self.id}/" - name = escape(self.datasource_name) - return Markup(f'{name}') - - def get_metric_obj(self, metric_name: str) -> Dict[str, Any]: - return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0] - - def latest_metadata(self) -> Optional[Dict[str, Any]]: - """Returns segment metadata from the latest segment""" - logger.info("Syncing datasource [{}]".format(self.datasource_name)) - client = self.cluster.get_pydruid_client() - try: - results = client.time_boundary(datasource=self.datasource_name) - except IOError: - results = None - if results: - max_time = results[0]["result"]["maxTime"] - max_time = dparse(max_time) - else: - max_time = datetime.now() - # Query segmentMetadata for 7 days back. However, due to a bug, - # we need to set this interval to more than 1 day ago to exclude - # realtime segments, which triggered a bug (fixed in druid 0.8.2). - # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ - lbound = (max_time - timedelta(days=7)).isoformat() - if LooseVersion(self.cluster.druid_version) < LooseVersion("0.8.2"): - rbound = (max_time - timedelta(1)).isoformat() - else: - rbound = max_time.isoformat() - segment_metadata = None - try: - segment_metadata = client.segment_metadata( - datasource=self.datasource_name, - intervals=lbound + "/" + rbound, - merge=self.merge_flag, - analysisTypes=[], - ) - except Exception as ex: - logger.warning("Failed first attempt to get latest segment") - logger.exception(ex) - if not segment_metadata: - # if no segments in the past 7 days, look at all segments - lbound = datetime(1901, 1, 1).isoformat()[:10] - if LooseVersion(self.cluster.druid_version) < LooseVersion("0.8.2"): - rbound = datetime.now().isoformat() - else: - rbound = datetime(2050, 1, 1).isoformat()[:10] - try: - segment_metadata = client.segment_metadata( - datasource=self.datasource_name, - intervals=lbound + "/" + rbound, - merge=self.merge_flag, - analysisTypes=[], - ) - except Exception as ex: - logger.warning("Failed 2nd attempt to get latest segment") - logger.exception(ex) - if segment_metadata: - return segment_metadata[-1]["columns"] - return None - - def refresh_metrics(self) -> None: - for col in self.columns: - col.refresh_metrics() - - @classmethod - def sync_to_db_from_config( - cls, - druid_config: Dict[str, Any], - user: User, - cluster: DruidCluster, - refresh: bool = True, - ) -> None: - """Merges the ds config from druid_config into one stored in the db.""" - session = db.session - datasource = ( - session.query(cls).filter_by(datasource_name=druid_config["name"]).first() - ) - # Create a new datasource. - if not datasource: - datasource = cls( - datasource_name=druid_config["name"], - cluster=cluster, - owners=[user], - changed_by_fk=user.id, - created_by_fk=user.id, - ) - session.add(datasource) - elif not refresh: - return - - dimensions = druid_config["dimensions"] - col_objs = ( - session.query(DruidColumn) - .filter(DruidColumn.datasource_id == datasource.id) - .filter(DruidColumn.column_name.in_(dimensions)) - ) - col_objs = {col.column_name: col for col in col_objs} - for dim in dimensions: - col_obj = col_objs.get(dim, None) - if not col_obj: - col_obj = DruidColumn( - datasource_id=datasource.id, - column_name=dim, - groupby=True, - filterable=True, - # TODO: fetch type from Hive. - type="STRING", - datasource=datasource, - ) - session.add(col_obj) - # Import Druid metrics - metric_objs = ( - session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter( - DruidMetric.metric_name.in_( - spec["name"] for spec in druid_config["metrics_spec"] - ) - ) - ) - metric_objs = {metric.metric_name: metric for metric in metric_objs} - for metric_spec in druid_config["metrics_spec"]: - metric_name = metric_spec["name"] - metric_type = metric_spec["type"] - metric_json = json.dumps(metric_spec) - - if metric_type == "count": - metric_type = "longSum" - metric_json = json.dumps( - {"type": "longSum", "name": metric_name, "fieldName": metric_name} - ) - - metric_obj = metric_objs.get(metric_name, None) - if not metric_obj: - metric_obj = DruidMetric( - metric_name=metric_name, - metric_type=metric_type, - verbose_name="%s(%s)" % (metric_type, metric_name), - datasource=datasource, - json=metric_json, - description=( - "Imported from the airolap config dir for %s" - % druid_config["name"] - ), - ) - session.add(metric_obj) - session.commit() - - @staticmethod - def time_offset(granularity: Granularity) -> int: - if granularity == "week_ending_saturday": - return 6 * 24 * 3600 * 1000 # 6 days - return 0 - - @classmethod - def get_datasource_by_name( - cls, session: Session, datasource_name: str, schema: str, database_name: str - ) -> Optional["DruidDatasource"]: - query = ( - session.query(cls) - .join(DruidCluster) - .filter(cls.datasource_name == datasource_name) - .filter(DruidCluster.cluster_name == database_name) - ) - return query.first() - - # uses https://en.wikipedia.org/wiki/ISO_8601 - # http://druid.io/docs/0.8.0/querying/granularities.html - # TODO: pass origin from the UI - @staticmethod - def granularity( - period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None - ) -> Union[Dict[str, str], str]: - if not period_name or period_name == "all": - return "all" - iso_8601_dict = { - "5 seconds": "PT5S", - "30 seconds": "PT30S", - "1 minute": "PT1M", - "5 minutes": "PT5M", - "30 minutes": "PT30M", - "1 hour": "PT1H", - "6 hour": "PT6H", - "one day": "P1D", - "1 day": "P1D", - "7 days": "P7D", - "week": "P1W", - "week_starting_sunday": "P1W", - "week_ending_saturday": "P1W", - "month": "P1M", - "quarter": "P3M", - "year": "P1Y", - } - - granularity = {"type": "period"} - if timezone: - granularity["timeZone"] = timezone - - if origin: - dttm = parse_human_datetime(origin) - assert dttm - granularity["origin"] = dttm.isoformat() - - if period_name in iso_8601_dict: - granularity["period"] = iso_8601_dict[period_name] - if period_name in ("week_ending_saturday", "week_starting_sunday"): - # use Sunday as start of the week - granularity["origin"] = "2016-01-03T00:00:00" - elif not isinstance(period_name, str): - granularity["type"] = "duration" - granularity["duration"] = period_name - elif period_name.startswith("P"): - # identify if the string is the iso_8601 period - granularity["period"] = period_name - else: - granularity["type"] = "duration" - granularity["duration"] = ( - parse_human_timedelta(period_name).total_seconds() # type: ignore - * 1000 - ) - return granularity - - @staticmethod - def get_post_agg(mconf: Dict[str, Any]) -> "Postaggregator": - """ - For a metric specified as `postagg` returns the - kind of post aggregation for pydruid. - """ - if mconf.get("type") == "javascript": - return JavascriptPostAggregator( - name=mconf.get("name", ""), - field_names=mconf.get("fieldNames", []), - function=mconf.get("function", ""), - ) - elif mconf.get("type") == "quantile": - return Quantile(mconf.get("name", ""), mconf.get("probability", "")) - elif mconf.get("type") == "quantiles": - return Quantiles(mconf.get("name", ""), mconf.get("probabilities", "")) - elif mconf.get("type") == "fieldAccess": - return Field(mconf.get("name")) - elif mconf.get("type") == "constant": - return Const(mconf.get("value"), output_name=mconf.get("name", "")) - elif mconf.get("type") == "hyperUniqueCardinality": - return HyperUniqueCardinality(mconf.get("name")) - elif mconf.get("type") == "arithmetic": - return Postaggregator( - mconf.get("fn", "/"), mconf.get("fields", []), mconf.get("name", "") - ) - else: - return CustomPostAggregator(mconf.get("name", ""), mconf) - - @staticmethod - def find_postaggs_for( - postagg_names: Set[str], metrics_dict: Dict[str, DruidMetric] - ) -> List[DruidMetric]: - """Return a list of metrics that are post aggregations""" - postagg_metrics = [ - metrics_dict[name] - for name in postagg_names - if metrics_dict[name].metric_type == POST_AGG_TYPE - ] - # Remove post aggregations that were found - for postagg in postagg_metrics: - postagg_names.remove(postagg.metric_name) - return postagg_metrics - - @staticmethod - def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]: - _type = _conf.get("type") - _field = _conf.get("field") - _fields = _conf.get("fields") - field_names = [] - if _type in ["fieldAccess", "hyperUniqueCardinality", "quantile", "quantiles"]: - field_names.append(_conf.get("fieldName", "")) - if _field: - field_names += DruidDatasource.recursive_get_fields(_field) - if _fields: - for _f in _fields: - field_names += DruidDatasource.recursive_get_fields(_f) - return list(set(field_names)) - - @staticmethod - def resolve_postagg( - postagg: DruidMetric, - post_aggs: Dict[str, Any], - agg_names: Set[str], - visited_postaggs: Set[str], - metrics_dict: Dict[str, DruidMetric], - ) -> None: - mconf = postagg.json_obj - required_fields = set( - DruidDatasource.recursive_get_fields(mconf) + mconf.get("fieldNames", []) - ) - # Check if the fields are already in aggs - # or is a previous postagg - required_fields = set( - field - for field in required_fields - if field not in visited_postaggs and field not in agg_names - ) - # First try to find postaggs that match - if len(required_fields) > 0: - missing_postaggs = DruidDatasource.find_postaggs_for( - required_fields, metrics_dict - ) - for missing_metric in required_fields: - agg_names.add(missing_metric) - for missing_postagg in missing_postaggs: - # Add to visited first to avoid infinite recursion - # if post aggregations are cyclicly dependent - visited_postaggs.add(missing_postagg.metric_name) - for missing_postagg in missing_postaggs: - DruidDatasource.resolve_postagg( - missing_postagg, - post_aggs, - agg_names, - visited_postaggs, - metrics_dict, - ) - post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) - - @staticmethod - def metrics_and_post_aggs( - metrics: List[Metric], metrics_dict: Dict[str, DruidMetric] - ) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]: - # Separate metrics into those that are aggregations - # and those that are post aggregations - saved_agg_names = set() - adhoc_agg_configs = [] - postagg_names = [] - for metric in metrics: - if isinstance(metric, dict) and utils.is_adhoc_metric(metric): - adhoc_agg_configs.append(metric) - elif isinstance(metric, str): - if metrics_dict[metric].metric_type != POST_AGG_TYPE: - saved_agg_names.add(metric) - else: - postagg_names.append(metric) - # Create the post aggregations, maintain order since postaggs - # may depend on previous ones - post_aggs: "OrderedDict[str, Postaggregator]" = OrderedDict() - visited_postaggs = set() - for postagg_name in postagg_names: - postagg = metrics_dict[postagg_name] - visited_postaggs.add(postagg_name) - DruidDatasource.resolve_postagg( - postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict - ) - aggs = DruidDatasource.get_aggregations( - metrics_dict, saved_agg_names, adhoc_agg_configs - ) - return aggs, post_aggs - - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: - """Retrieve some values for the given column""" - logger.info( - "Getting values for columns [{}] limited to [{}]".format(column_name, limit) - ) - # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid - if self.fetch_values_from: - from_dttm = parse_human_datetime(self.fetch_values_from) - assert from_dttm - else: - from_dttm = datetime(1970, 1, 1) - - qry = dict( - datasource=self.datasource_name, - granularity="all", - intervals=from_dttm.isoformat() + "/" + datetime.now().isoformat(), - aggregations=dict(count=count("count")), - dimension=column_name, - metric="count", - threshold=limit, - ) - - client = self.cluster.get_pydruid_client() - client.topn(**qry) - df = client.export_pandas() - return df[column_name].to_list() - - def get_query_str( - self, - query_obj: QueryObjectDict, - phase: int = 1, - client: Optional["PyDruid"] = None, - ) -> str: - return self.run_query(client=client, phase=phase, **query_obj) - - def _add_filter_from_pre_query_data( - self, df: pd.DataFrame, dimensions: List[Any], dim_filter: "Filter" - ) -> "Filter": - ret = dim_filter - if not df.empty: - new_filters = [] - for unused, row in df.iterrows(): - fields = [] - for dim in dimensions: - f = None - # Check if this dimension uses an extraction function - # If so, create the appropriate pydruid extraction object - if isinstance(dim, dict) and "extractionFn" in dim: - (col, extraction_fn) = DruidDatasource._create_extraction_fn( - dim - ) - dim_val = dim["outputName"] - f = Filter( - dimension=col, - value=row[dim_val], - extraction_function=extraction_fn, - ) - elif isinstance(dim, dict): - dim_val = dim["outputName"] - if dim_val: - f = Dimension(dim_val) == row[dim_val] - else: - f = Dimension(dim) == row[dim] - if f: - fields.append(f) - if len(fields) > 1: - term = Filter(type="and", fields=fields) - new_filters.append(term) - elif fields: - new_filters.append(fields[0]) - if new_filters: - ff = Filter(type="or", fields=new_filters) - if not dim_filter: - ret = ff - else: - ret = Filter(type="and", fields=[ff, dim_filter]) - return ret - - @staticmethod - def druid_type_from_adhoc_metric(adhoc_metric: AdhocMetric) -> str: - column_type = adhoc_metric["column"]["type"].lower() # type: ignore - aggregate = adhoc_metric["aggregate"].lower() - - if aggregate == "count": - return "count" - if aggregate == "count_distinct": - return "hyperUnique" if column_type == "hyperunique" else "cardinality" - else: - return column_type + aggregate.capitalize() - - @staticmethod - def get_aggregations( - metrics_dict: Dict[str, Any], - saved_metrics: Set[str], - adhoc_metrics: Optional[List[AdhocMetric]] = None, - ) -> "OrderedDict[str, Any]": - """ - Returns a dictionary of aggregation metric names to aggregation json objects - - :param metrics_dict: dictionary of all the metrics - :param saved_metrics: list of saved metric names - :param adhoc_metrics: list of adhoc metric names - :raise SupersetException: if one or more metric names are not aggregations - """ - if not adhoc_metrics: - adhoc_metrics = [] - aggregations = OrderedDict() - invalid_metric_names = [] - for metric_name in saved_metrics: - if metric_name in metrics_dict: - metric = metrics_dict[metric_name] - if metric.metric_type == POST_AGG_TYPE: - invalid_metric_names.append(metric_name) - else: - aggregations[metric_name] = metric.json_obj - else: - invalid_metric_names.append(metric_name) - if len(invalid_metric_names) > 0: - raise SupersetException( - _("Metric(s) {} must be aggregations.").format(invalid_metric_names) - ) - for adhoc_metric in adhoc_metrics: - label = get_metric_name(adhoc_metric) - column = cast(AdhocMetricColumn, adhoc_metric["column"]) - aggregations[label] = { - "fieldName": column["column_name"], - "fieldNames": [column["column_name"]], - "type": DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), - "name": label, - } - return aggregations - - def get_dimensions( - self, columns: List[str], columns_dict: Dict[str, DruidColumn] - ) -> List[Union[str, Dict[str, Any]]]: - dimensions = [] - columns = [col for col in columns if col in columns_dict] - for column_name in columns: - col = columns_dict.get(column_name) - dim_spec = col.dimension_spec if col else None - dimensions.append(dim_spec or column_name) - return dimensions - - def intervals_from_dttms(self, from_dttm: datetime, to_dttm: datetime) -> str: - # Couldn't find a way to just not filter on time... - from_dttm = from_dttm or datetime(1901, 1, 1) - to_dttm = to_dttm or datetime(2101, 1, 1) - - # add tzinfo to native datetime with config - from_dttm = from_dttm.replace(tzinfo=DRUID_TZ) - to_dttm = to_dttm.replace(tzinfo=DRUID_TZ) - return "{}/{}".format( - from_dttm.isoformat() if from_dttm else "", - to_dttm.isoformat() if to_dttm else "", - ) - - @staticmethod - def _dimensions_to_values( - dimensions: List[Union[Dict[str, str], str]] - ) -> List[Union[Dict[str, str], str]]: - """ - Replace dimensions specs with their `dimension` - values, and ignore those without - """ - values: List[Union[Dict[str, str], str]] = [] - for dimension in dimensions: - if isinstance(dimension, dict): - if "extractionFn" in dimension: - values.append(dimension) - elif "dimension" in dimension: - values.append(dimension["dimension"]) - else: - values.append(dimension) - - return values - - @staticmethod - def sanitize_metric_object(metric: Metric) -> None: - """ - Update a metric with the correct type if necessary. - :param dict metric: The metric to sanitize - """ - if ( - utils.is_adhoc_metric(metric) - and metric["column"]["type"].upper() == "FLOAT" # type: ignore - ): - metric["column"]["type"] = "DOUBLE" # type: ignore - - def run_query( # druid - self, - metrics: List[Metric], - granularity: str, - from_dttm: datetime, - to_dttm: datetime, - columns: Optional[List[str]] = None, - groupby: Optional[List[str]] = None, - filter: Optional[List[Dict[str, Any]]] = None, - is_timeseries: Optional[bool] = True, - timeseries_limit: Optional[int] = None, - timeseries_limit_metric: Optional[Metric] = None, - row_limit: Optional[int] = None, - row_offset: Optional[int] = None, - inner_from_dttm: Optional[datetime] = None, - inner_to_dttm: Optional[datetime] = None, - orderby: Optional[Any] = None, - extras: Optional[Dict[str, Any]] = None, - phase: int = 2, - client: Optional["PyDruid"] = None, - order_desc: bool = True, - is_rowcount: bool = False, - apply_fetch_values_predicate: bool = False, - ) -> str: - """Runs a query against Druid and returns a dataframe.""" - # is_rowcount and apply_fetch_values_predicate is only - # supported on SQL connector - if is_rowcount: - raise SupersetException("is_rowcount is not supported on Druid connector") - if apply_fetch_values_predicate: - raise SupersetException( - "apply_fetch_values_predicate is not supported on Druid connector" - ) - - # TODO refactor into using a TBD Query object - client = client or self.cluster.get_pydruid_client() - row_limit = row_limit or conf.get("ROW_LIMIT") - if row_offset: - raise SupersetException("Offset not implemented for Druid connector") - - if not is_timeseries: - granularity = "all" - - if granularity == "all": - phase = 1 - inner_from_dttm = inner_from_dttm or from_dttm - inner_to_dttm = inner_to_dttm or to_dttm - - timezone = from_dttm.replace(tzinfo=DRUID_TZ).tzname() if from_dttm else None - - query_str = "" - metrics_dict = {m.metric_name: m for m in self.metrics} - columns_dict = {c.column_name: c for c in self.columns} - - if self.cluster and LooseVersion( - self.cluster.get_druid_version() - ) < LooseVersion("0.11.0"): - for metric in metrics: - self.sanitize_metric_object(metric) - if timeseries_limit_metric: - self.sanitize_metric_object(timeseries_limit_metric) - - aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - # the dimensions list with dimensionSpecs expanded - dimensions = self.get_dimensions(groupby, columns_dict) if groupby else [] - - extras = extras or {} - qry = dict( - datasource=self.datasource_name, - dimensions=dimensions, - aggregations=aggregations, - granularity=DruidDatasource.granularity( - granularity, timezone=timezone, origin=extras.get("druid_time_origin") - ), - post_aggregations=post_aggs, - intervals=self.intervals_from_dttms(from_dttm, to_dttm), - ) - - if is_timeseries: - qry["context"] = dict(skipEmptyBuckets=True) - - filters = ( - DruidDatasource.get_filters(filter, self.num_cols, columns_dict) - if filter - else None - ) - if filters: - qry["filter"] = filters - - if "having_druid" in extras: - having_filters = self.get_having_filters(extras["having_druid"]) - if having_filters: - qry["having"] = having_filters - else: - having_filters = None - - order_direction = "descending" if order_desc else "ascending" - - if columns: - columns.append("__time") - del qry["post_aggregations"] - del qry["aggregations"] - del qry["dimensions"] - qry["columns"] = columns - qry["metrics"] = [] - qry["granularity"] = "all" - qry["limit"] = row_limit - client.scan(**qry) - elif not groupby and not having_filters: - logger.info("Running timeseries query for no groupby values") - del qry["dimensions"] - client.timeseries(**qry) - elif not having_filters and order_desc and (groupby and len(groupby) == 1): - dim = list(qry["dimensions"])[0] - logger.info("Running two-phase topn query for dimension [{}]".format(dim)) - pre_qry = deepcopy(qry) - order_by: Optional[str] = None - if timeseries_limit_metric: - order_by = utils.get_metric_name(timeseries_limit_metric) - aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( - [timeseries_limit_metric], metrics_dict - ) - if phase == 1: - pre_qry["aggregations"].update(aggs_dict) - pre_qry["post_aggregations"].update(post_aggs_dict) - else: - pre_qry["aggregations"] = aggs_dict - pre_qry["post_aggregations"] = post_aggs_dict - else: - agg_keys = qry["aggregations"].keys() - order_by = list(agg_keys)[0] if agg_keys else None - - # Limit on the number of timeseries, doing a two-phases query - pre_qry["granularity"] = "all" - pre_qry["threshold"] = min(row_limit, timeseries_limit or row_limit) - pre_qry["metric"] = order_by - pre_qry["dimension"] = self._dimensions_to_values(qry["dimensions"])[0] - del pre_qry["dimensions"] - - client.topn(**pre_qry) - logger.info("Phase 1 Complete") - if phase == 2: - query_str += "// Two phase query\n// Phase 1\n" - query_str += json.dumps( - client.query_builder.last_query.query_dict, indent=2 - ) - query_str += "\n" - if phase == 1: - return query_str - query_str += "// Phase 2 (built based on phase one's results)\n" - df = client.export_pandas() - if df is None: - df = pd.DataFrame() - qry["filter"] = self._add_filter_from_pre_query_data( - df, [pre_qry["dimension"]], filters - ) - qry["threshold"] = timeseries_limit or 1000 - if row_limit and granularity == "all": - qry["threshold"] = row_limit - qry["dimension"] = dim - del qry["dimensions"] - qry["metric"] = list(qry["aggregations"].keys())[0] - client.topn(**qry) - logger.info("Phase 2 Complete") - elif having_filters or groupby: - # If grouping on multiple fields or using a having filter - # we have to force a groupby query - logger.info("Running groupby query for dimensions [{}]".format(dimensions)) - if timeseries_limit and is_timeseries: - logger.info("Running two-phase query for timeseries") - - pre_qry = deepcopy(qry) - pre_qry_dims = self._dimensions_to_values(qry["dimensions"]) - - # Can't use set on an array with dicts - # Use set with non-dict items only - non_dict_dims = list( - set([x for x in pre_qry_dims if not isinstance(x, dict)]) - ) - dict_dims = [x for x in pre_qry_dims if isinstance(x, dict)] - pre_qry["dimensions"] = non_dict_dims + dict_dims # type: ignore - - order_by = None - if metrics: - order_by = utils.get_metric_name(metrics[0]) - else: - order_by = pre_qry_dims[0] # type: ignore - - if timeseries_limit_metric: - order_by = utils.get_metric_name(timeseries_limit_metric) - aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( - [timeseries_limit_metric], metrics_dict - ) - if phase == 1: - pre_qry["aggregations"].update(aggs_dict) - pre_qry["post_aggregations"].update(post_aggs_dict) - else: - pre_qry["aggregations"] = aggs_dict - pre_qry["post_aggregations"] = post_aggs_dict - - # Limit on the number of timeseries, doing a two-phases query - pre_qry["granularity"] = "all" - pre_qry["limit_spec"] = { - "type": "default", - "limit": min(timeseries_limit, row_limit), - "intervals": self.intervals_from_dttms( - inner_from_dttm, inner_to_dttm - ), - "columns": [{"dimension": order_by, "direction": order_direction}], - } - client.groupby(**pre_qry) - logger.info("Phase 1 Complete") - query_str += "// Two phase query\n// Phase 1\n" - query_str += json.dumps( - client.query_builder.last_query.query_dict, indent=2 - ) - query_str += "\n" - if phase == 1: - return query_str - query_str += "// Phase 2 (built based on phase one's results)\n" - df = client.export_pandas() - if df is None: - df = pd.DataFrame() - qry["filter"] = self._add_filter_from_pre_query_data( - df, pre_qry["dimensions"], filters - ) - qry["limit_spec"] = None - if row_limit: - dimension_values = self._dimensions_to_values(dimensions) - qry["limit_spec"] = { - "type": "default", - "limit": row_limit, - "columns": [ - { - "dimension": ( - utils.get_metric_name(metrics[0]) - if metrics - else dimension_values[0] - ), - "direction": order_direction, - } - ], - } - client.groupby(**qry) - logger.info("Query Complete") - query_str += json.dumps(client.query_builder.last_query.query_dict, indent=2) - return query_str - - @staticmethod - def homogenize_types(df: pd.DataFrame, columns: Iterable[str]) -> pd.DataFrame: - """Converting all columns to strings - - When grouping by a numeric (say FLOAT) column, pydruid returns - strings in the dataframe. This creates issues downstream related - to having mixed types in the dataframe - - Here we replace None with and make the whole series a - str instead of an object. - """ - df[columns] = df[columns].fillna(NULL_STRING).astype("unicode") - return df - - def query(self, query_obj: QueryObjectDict) -> QueryResult: - qry_start_dttm = datetime.now() - client = self.cluster.get_pydruid_client() - query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2) - df = client.export_pandas() - if df is None: - df = pd.DataFrame() - - if df.empty: - return QueryResult( - df=df, query=query_str, duration=datetime.now() - qry_start_dttm - ) - - df = self.homogenize_types(df, query_obj.get("groupby", [])) - df.columns = [ - DTTM_ALIAS if c in ("timestamp", "__time") else c for c in df.columns - ] - - is_timeseries = ( - query_obj["is_timeseries"] if "is_timeseries" in query_obj else True - ) - if not is_timeseries and DTTM_ALIAS in df.columns: - del df[DTTM_ALIAS] - - # Reordering columns - cols: List[str] = [] - if DTTM_ALIAS in df.columns: - cols += [DTTM_ALIAS] - - cols += query_obj.get("groupby") or [] - cols += query_obj.get("columns") or [] - cols += query_obj.get("metrics") or [] - - cols = utils.get_metric_names(cols) - cols = [col for col in cols if col in df.columns] - df = df[cols] - - time_offset = DruidDatasource.time_offset(query_obj["granularity"]) - - def increment_timestamp(ts: str) -> datetime: - dt = parse_human_datetime(ts).replace(tzinfo=DRUID_TZ) - return dt + timedelta(milliseconds=time_offset) - - if DTTM_ALIAS in df.columns and time_offset: - df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp) - - return QueryResult( - df=df, query=query_str, duration=datetime.now() - qry_start_dttm - ) - - @staticmethod - def _create_extraction_fn( - dim_spec: Dict[str, Any] - ) -> Tuple[ - str, - Union[ - "MapLookupExtraction", - "RegexExtraction", - "RegisteredLookupExtraction", - "TimeFormatExtraction", - ], - ]: - extraction_fn = None - if dim_spec and "extractionFn" in dim_spec: - col = dim_spec["dimension"] - fn = dim_spec["extractionFn"] - ext_type = fn.get("type") - if ext_type == "lookup" and fn["lookup"].get("type") == "map": - replace_missing_values = fn.get("replaceMissingValueWith") - retain_missing_values = fn.get("retainMissingValue", False) - injective = fn.get("isOneToOne", False) - extraction_fn = MapLookupExtraction( - fn["lookup"]["map"], - replace_missing_values=replace_missing_values, - retain_missing_values=retain_missing_values, - injective=injective, - ) - elif ext_type == "regex": - extraction_fn = RegexExtraction(fn["expr"]) - elif ext_type == "registeredLookup": - extraction_fn = RegisteredLookupExtraction(fn.get("lookup")) - elif ext_type == "timeFormat": - extraction_fn = TimeFormatExtraction( - fn.get("format"), fn.get("locale"), fn.get("timeZone") - ) - else: - raise Exception(_("Unsupported extraction function: " + ext_type)) - return (col, extraction_fn) - - @classmethod - def get_filters( - cls, - raw_filters: List[Dict[str, Any]], - num_cols: List[str], - columns_dict: Dict[str, DruidColumn], - ) -> "Filter": - """Given Superset filter data structure, returns pydruid Filter(s)""" - filters = None - for flt in raw_filters: - col: Optional[str] = flt.get("col") - op: Optional[str] = flt["op"].upper() if "op" in flt else None - eq: Optional[FilterValues] = flt.get("val") - if ( - not col - or not op - or ( - eq is None - and op - not in ( - FilterOperator.IS_NULL.value, - FilterOperator.IS_NOT_NULL.value, - ) - ) - ): - continue - - # Check if this dimension uses an extraction function - # If so, create the appropriate pydruid extraction object - column_def = columns_dict.get(col) - dim_spec = column_def.dimension_spec if column_def else None - extraction_fn = None - if dim_spec and "extractionFn" in dim_spec: - (col, extraction_fn) = DruidDatasource._create_extraction_fn(dim_spec) - - cond = None - is_numeric_col = col in num_cols - is_list_target = op in ( - FilterOperator.IN.value, - FilterOperator.NOT_IN.value, - ) - eq = cls.filter_values_handler( - eq, - is_list_target=is_list_target, - target_generic_type=utils.GenericDataType.NUMERIC - if is_numeric_col - else utils.GenericDataType.STRING, - ) - - # For these two ops, could have used Dimension, - # but it doesn't support extraction functions - if op == FilterOperator.EQUALS.value: - cond = Filter( - dimension=col, value=eq, extraction_function=extraction_fn - ) - elif op == FilterOperator.NOT_EQUALS.value: - cond = ~Filter( - dimension=col, value=eq, extraction_function=extraction_fn - ) - elif is_list_target: - eq = cast(List[Any], eq) - fields = [] - # ignore the filter if it has no value - if not len(eq): - continue - # if it uses an extraction fn, use the "in" operator - # as Dimension isn't supported - elif extraction_fn is not None: - cond = Filter( - dimension=col, - values=eq, - type="in", - extraction_function=extraction_fn, - ) - elif len(eq) == 1: - cond = Dimension(col) == eq[0] - else: - for s in eq: - fields.append(Dimension(col) == s) - cond = Filter(type="or", fields=fields) - if op == FilterOperator.NOT_IN.value: - cond = ~cond - elif op == FilterOperator.REGEX.value: - cond = Filter( - extraction_function=extraction_fn, - type="regex", - pattern=eq, - dimension=col, - ) - - # For the ops below, could have used pydruid's Bound, - # but it doesn't support extraction functions - elif op == FilterOperator.GREATER_THAN_OR_EQUALS.value: - cond = Bound( - extraction_function=extraction_fn, - dimension=col, - lowerStrict=False, - upperStrict=False, - lower=eq, - upper=None, - ordering=cls._get_ordering(is_numeric_col), - ) - elif op == FilterOperator.LESS_THAN_OR_EQUALS.value: - cond = Bound( - extraction_function=extraction_fn, - dimension=col, - lowerStrict=False, - upperStrict=False, - lower=None, - upper=eq, - ordering=cls._get_ordering(is_numeric_col), - ) - elif op == FilterOperator.GREATER_THAN.value: - cond = Bound( - extraction_function=extraction_fn, - lowerStrict=True, - upperStrict=False, - dimension=col, - lower=eq, - upper=None, - ordering=cls._get_ordering(is_numeric_col), - ) - elif op == FilterOperator.LESS_THAN.value: - cond = Bound( - extraction_function=extraction_fn, - upperStrict=True, - lowerStrict=False, - dimension=col, - lower=None, - upper=eq, - ordering=cls._get_ordering(is_numeric_col), - ) - elif op == FilterOperator.IS_NULL.value: - cond = Filter(dimension=col, value="") - elif op == FilterOperator.IS_NOT_NULL.value: - cond = ~Filter(dimension=col, value="") - - if filters: - filters = Filter(type="and", fields=[cond, filters]) - else: - filters = cond - - return filters - - @staticmethod - def _get_ordering(is_numeric_col: bool) -> str: - return "numeric" if is_numeric_col else "lexicographic" - - def _get_having_obj(self, col: str, op: str, eq: str) -> "Having": - cond = None - if op == FilterOperator.EQUALS.value: - if col in self.column_names: - cond = DimSelector(dimension=col, value=eq) - else: - cond = Aggregation(col) == eq - elif op == FilterOperator.GREATER_THAN.value: - cond = Aggregation(col) > eq - elif op == FilterOperator.LESS_THAN.value: - cond = Aggregation(col) < eq - - return cond - - def get_having_filters( - self, raw_filters: List[Dict[str, Any]] - ) -> Optional["Having"]: - filters = None - reversed_op_map = { - FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value, - FilterOperator.GREATER_THAN_OR_EQUALS.value: FilterOperator.LESS_THAN.value, - FilterOperator.LESS_THAN_OR_EQUALS.value: FilterOperator.GREATER_THAN.value, - } - - for flt in raw_filters: - if not all(f in flt for f in ["col", "op", "val"]): - continue - col = flt["col"] - op = flt["op"] - eq = flt["val"] - cond = None - if op in [ - FilterOperator.EQUALS.value, - FilterOperator.GREATER_THAN.value, - FilterOperator.LESS_THAN.value, - ]: - cond = self._get_having_obj(col, op, eq) - elif op in reversed_op_map: - cond = ~self._get_having_obj(col, reversed_op_map[op], eq) - - if filters: - filters = filters & cond - else: - filters = cond - return filters - - @classmethod - def query_datasources_by_name( - cls, - session: Session, - database: Database, - datasource_name: str, - schema: Optional[str] = None, - ) -> List["DruidDatasource"]: - return [] - - def external_metadata(self) -> List[Dict[str, Any]]: - self.merge_flag = True - latest_metadata = self.latest_metadata() or {} - return [{"name": k, "type": v.get("type")} for k, v in latest_metadata.items()] - - @staticmethod - def update_datasource( - _mapper: Mapper, _connection: Connection, obj: Union[DruidColumn, DruidMetric] - ) -> None: - """ - Forces an update to the datasource's changed_on value when a metric or column on - the datasource is updated. This busts the cache key for all charts that use the - datasource. - - :param _mapper: Unused. - :param _connection: Unused. - :param obj: The metric or column that was updated. - """ - db.session.execute( - update(DruidDatasource).where(DruidDatasource.id == obj.datasource.id) - ) - - -sa.event.listen(DruidDatasource, "after_insert", security_manager.set_perm) -sa.event.listen(DruidDatasource, "after_update", security_manager.set_perm) -sa.event.listen(DruidMetric, "after_update", DruidDatasource.update_datasource) -sa.event.listen(DruidColumn, "after_update", DruidDatasource.update_datasource) diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py deleted file mode 100644 index b387aff6962e..000000000000 --- a/superset/connectors/druid/views.py +++ /dev/null @@ -1,445 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -import logging -from datetime import datetime - -from flask import current_app as app, flash, Markup, redirect -from flask_appbuilder import CompactCRUDMixin, expose -from flask_appbuilder.fieldwidgets import Select2Widget -from flask_appbuilder.hooks import before_request -from flask_appbuilder.models.sqla.interface import SQLAInterface -from flask_appbuilder.security.decorators import has_access -from flask_babel import lazy_gettext as _ -from werkzeug.exceptions import NotFound -from wtforms import StringField -from wtforms.ext.sqlalchemy.fields import QuerySelectField - -from superset import db, security_manager -from superset.connectors.base.views import BS3TextFieldROWidget, DatasourceModelView -from superset.connectors.connector_registry import ConnectorRegistry -from superset.connectors.druid import models -from superset.constants import RouteMethod -from superset.superset_typing import FlaskResponse -from superset.utils import core as utils -from superset.views.base import ( - BaseSupersetView, - DatasourceFilter, - DeleteMixin, - get_dataset_exist_error_msg, - ListWidgetWithCheckboxes, - SupersetModelView, - validate_json, - YamlExportMixin, -) - -logger = logging.getLogger(__name__) - - -class EnsureEnabledMixin: - @staticmethod - def is_enabled() -> bool: - return bool(app.config["DRUID_IS_ACTIVE"]) - - @before_request - def ensure_enabled(self) -> None: - if not self.is_enabled(): - raise NotFound() - - -class DruidColumnInlineView( # pylint: disable=too-many-ancestors - CompactCRUDMixin, - EnsureEnabledMixin, - SupersetModelView, -): - datamodel = SQLAInterface(models.DruidColumn) - include_route_methods = RouteMethod.RELATED_VIEW_SET - - list_title = _("Columns") - show_title = _("Show Druid Column") - add_title = _("Add Druid Column") - edit_title = _("Edit Druid Column") - - list_widget = ListWidgetWithCheckboxes - - edit_columns = [ - "column_name", - "verbose_name", - "description", - "dimension_spec_json", - "datasource", - "groupby", - "filterable", - ] - add_columns = edit_columns - list_columns = ["column_name", "verbose_name", "type", "groupby", "filterable"] - can_delete = False - page_size = 500 - label_columns = { - "column_name": _("Column"), - "type": _("Type"), - "datasource": _("Datasource"), - "groupby": _("Groupable"), - "filterable": _("Filterable"), - } - description_columns = { - "filterable": _( - "Whether this column is exposed in the `Filters` section " - "of the explore view." - ), - "dimension_spec_json": utils.markdown( - "this field can be used to specify " - "a `dimensionSpec` as documented [here]" - "(http://druid.io/docs/latest/querying/dimensionspecs.html). " - "Make sure to input valid JSON and that the " - "`outputName` matches the `column_name` defined " - "above.", - True, - ), - } - - add_form_extra_fields = { - "datasource": QuerySelectField( - "Datasource", - query_factory=lambda: db.session.query(models.DruidDatasource), - allow_blank=True, - widget=Select2Widget(extra_classes="readonly"), - ) - } - - edit_form_extra_fields = add_form_extra_fields - - def pre_update(self, item: "DruidColumnInlineView") -> None: - # If a dimension spec JSON is given, ensure that it is - # valid JSON and that `outputName` is specified - if item.dimension_spec_json: - try: - dimension_spec = json.loads(item.dimension_spec_json) - except ValueError as ex: - raise ValueError("Invalid Dimension Spec JSON: " + str(ex)) from ex - if not isinstance(dimension_spec, dict): - raise ValueError("Dimension Spec must be a JSON object") - if "outputName" not in dimension_spec: - raise ValueError("Dimension Spec does not contain `outputName`") - if "dimension" not in dimension_spec: - raise ValueError("Dimension Spec is missing `dimension`") - # `outputName` should be the same as the `column_name` - if dimension_spec["outputName"] != item.column_name: - raise ValueError( - "`outputName` [{}] unequal to `column_name` [{}]".format( - dimension_spec["outputName"], item.column_name - ) - ) - - def post_update(self, item: "DruidColumnInlineView") -> None: - item.refresh_metrics() - - def post_add(self, item: "DruidColumnInlineView") -> None: - self.post_update(item) - - -class DruidMetricInlineView( # pylint: disable=too-many-ancestors - CompactCRUDMixin, - EnsureEnabledMixin, - SupersetModelView, -): - datamodel = SQLAInterface(models.DruidMetric) - include_route_methods = RouteMethod.RELATED_VIEW_SET - - list_title = _("Metrics") - show_title = _("Show Druid Metric") - add_title = _("Add Druid Metric") - edit_title = _("Edit Druid Metric") - - list_columns = ["metric_name", "verbose_name", "metric_type"] - edit_columns = [ - "metric_name", - "description", - "verbose_name", - "metric_type", - "json", - "datasource", - "d3format", - "warning_text", - ] - add_columns = edit_columns - page_size = 500 - validators_columns = {"json": [validate_json]} - description_columns = { - "metric_type": utils.markdown( - "use `postagg` as the metric type if you are defining a " - "[Druid Post Aggregation]" - "(http://druid.io/docs/latest/querying/post-aggregations.html)", - True, - ) - } - label_columns = { - "metric_name": _("Metric"), - "description": _("Description"), - "verbose_name": _("Verbose Name"), - "metric_type": _("Type"), - "json": _("JSON"), - "datasource": _("Druid Datasource"), - "warning_text": _("Warning Message"), - } - - add_form_extra_fields = { - "datasource": QuerySelectField( - "Datasource", - query_factory=lambda: db.session.query(models.DruidDatasource), - allow_blank=True, - widget=Select2Widget(extra_classes="readonly"), - ) - } - - edit_form_extra_fields = add_form_extra_fields - - -class DruidClusterModelView( # pylint: disable=too-many-ancestors - EnsureEnabledMixin, - SupersetModelView, - DeleteMixin, - YamlExportMixin, -): - datamodel = SQLAInterface(models.DruidCluster) - include_route_methods = RouteMethod.CRUD_SET - list_title = _("Druid Clusters") - show_title = _("Show Druid Cluster") - add_title = _("Add Druid Cluster") - edit_title = _("Edit Druid Cluster") - - add_columns = [ - "verbose_name", - "broker_host", - "broker_port", - "broker_user", - "broker_pass", - "broker_endpoint", - "cache_timeout", - "cluster_name", - ] - edit_columns = add_columns - list_columns = ["cluster_name", "metadata_last_refreshed"] - search_columns = ("cluster_name",) - label_columns = { - "cluster_name": _("Cluster Name"), - "broker_host": _("Broker Host"), - "broker_port": _("Broker Port"), - "broker_user": _("Broker Username"), - "broker_pass": _("Broker Password"), - "broker_endpoint": _("Broker Endpoint"), - "verbose_name": _("Verbose Name"), - "cache_timeout": _("Cache Timeout"), - "metadata_last_refreshed": _("Metadata Last Refreshed"), - } - description_columns = { - "cache_timeout": _( - "Duration (in seconds) of the caching timeout for this cluster. " - "A timeout of 0 indicates that the cache never expires. " - "Note this defaults to the global timeout if undefined." - ), - "broker_user": _( - "Druid supports basic authentication. See " - "[auth](http://druid.io/docs/latest/design/auth.html) and " - "druid-basic-security extension" - ), - "broker_pass": _( - "Druid supports basic authentication. See " - "[auth](http://druid.io/docs/latest/design/auth.html) and " - "druid-basic-security extension" - ), - } - - yaml_dict_key = "databases" - - def pre_add(self, item: "DruidClusterModelView") -> None: - security_manager.add_permission_view_menu("database_access", item.perm) - - def pre_update(self, item: "DruidClusterModelView") -> None: - self.pre_add(item) - - def _delete(self, pk: int) -> None: - DeleteMixin._delete(self, pk) - - -class DruidDatasourceModelView( # pylint: disable=too-many-ancestors - EnsureEnabledMixin, - DatasourceModelView, - DeleteMixin, - YamlExportMixin, -): - datamodel = SQLAInterface(models.DruidDatasource) - include_route_methods = RouteMethod.CRUD_SET - list_title = _("Druid Datasources") - show_title = _("Show Druid Datasource") - add_title = _("Add Druid Datasource") - edit_title = _("Edit Druid Datasource") - - list_columns = ["datasource_link", "cluster", "changed_by_", "modified"] - order_columns = ["datasource_link", "modified"] - related_views = [DruidColumnInlineView, DruidMetricInlineView] - edit_columns = [ - "datasource_name", - "cluster", - "description", - "owners", - "is_hidden", - "filter_select_enabled", - "fetch_values_from", - "default_endpoint", - "offset", - "cache_timeout", - ] - search_columns = ("datasource_name", "cluster", "description", "owners") - add_columns = edit_columns - show_columns = add_columns + ["perm", "slices"] - page_size = 500 - base_order = ("datasource_name", "asc") - description_columns = { - "slices": _( - "The list of charts associated with this table. By " - "altering this datasource, you may change how these associated " - "charts behave. " - "Also note that charts need to point to a datasource, so " - "this form will fail at saving if removing charts from a " - "datasource. If you want to change the datasource for a chart, " - "overwrite the chart from the 'explore view'" - ), - "offset": _("Timezone offset (in hours) for this datasource"), - "description": Markup( - 'Supports markdown' - ), - "fetch_values_from": _( - "Time expression to use as a predicate when retrieving " - "distinct values to populate the filter component. " - "Only applies when `Enable Filter Select` is on. If " - "you enter `7 days ago`, the distinct list of values in " - "the filter will be populated based on the distinct value over " - "the past week" - ), - "filter_select_enabled": _( - "Whether to populate the filter's dropdown in the explore " - "view's filter section with a list of distinct values fetched " - "from the backend on the fly" - ), - "default_endpoint": _( - "Redirects to this endpoint when clicking on the datasource " - "from the datasource list" - ), - "cache_timeout": _( - "Duration (in seconds) of the caching timeout for this datasource. " - "A timeout of 0 indicates that the cache never expires. " - "Note this defaults to the cluster timeout if undefined." - ), - } - base_filters = [["id", DatasourceFilter, lambda: []]] - label_columns = { - "slices": _("Associated Charts"), - "datasource_link": _("Data Source"), - "cluster": _("Cluster"), - "description": _("Description"), - "owners": _("Owners"), - "is_hidden": _("Is Hidden"), - "filter_select_enabled": _("Enable Filter Select"), - "default_endpoint": _("Default Endpoint"), - "offset": _("Time Offset"), - "cache_timeout": _("Cache Timeout"), - "datasource_name": _("Datasource Name"), - "fetch_values_from": _("Fetch Values From"), - "changed_by_": _("Changed By"), - "modified": _("Modified"), - } - edit_form_extra_fields = { - "cluster": QuerySelectField( - "Cluster", - query_factory=lambda: db.session.query(models.DruidCluster), - widget=Select2Widget(extra_classes="readonly"), - ), - "datasource_name": StringField( - "Datasource Name", widget=BS3TextFieldROWidget() - ), - } - - def pre_add(self, item: "DruidDatasourceModelView") -> None: - with db.session.no_autoflush: - query = db.session.query(models.DruidDatasource).filter( - models.DruidDatasource.datasource_name == item.datasource_name, - models.DruidDatasource.cluster_id == item.cluster_id, - ) - if db.session.query(query.exists()).scalar(): - raise Exception(get_dataset_exist_error_msg(item.full_name)) - - def post_add(self, item: "DruidDatasourceModelView") -> None: - item.refresh_metrics() - security_manager.add_permission_view_menu("datasource_access", item.get_perm()) - if item.schema: - security_manager.add_permission_view_menu("schema_access", item.schema_perm) - - def post_update(self, item: "DruidDatasourceModelView") -> None: - self.post_add(item) - - def _delete(self, pk: int) -> None: - DeleteMixin._delete(self, pk) - - -class Druid(EnsureEnabledMixin, BaseSupersetView): - """The base views for Superset!""" - - @has_access - @expose("/refresh_datasources/") - def refresh_datasources( # pylint: disable=no-self-use - self, refresh_all: bool = True - ) -> FlaskResponse: - """endpoint that refreshes druid datasources metadata""" - session = db.session() - DruidCluster = ConnectorRegistry.sources[ # pylint: disable=invalid-name - "druid" - ].cluster_class - for cluster in session.query(DruidCluster).all(): - cluster_name = cluster.cluster_name - valid_cluster = True - try: - cluster.refresh_datasources(refresh_all=refresh_all) - except Exception as ex: # pylint: disable=broad-except - valid_cluster = False - flash( - "Error while processing cluster '{}'\n{}".format( - cluster_name, utils.error_msg_from_exception(ex) - ), - "danger", - ) - logger.exception(ex) - if valid_cluster: - cluster.metadata_last_refreshed = datetime.now() - flash( - _("Refreshed metadata from cluster [{}]").format( - cluster.cluster_name - ), - "info", - ) - session.commit() - return redirect("/druiddatasourcemodelview/list/") - - @has_access - @expose("/scan_new_datasources/") - def scan_new_datasources(self) -> FlaskResponse: - """ - Calling this endpoint will cause a scan for new - datasources only and add them. - """ - return self.refresh_datasources(refresh_all=False) diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index a7fbb51c057d..207920b1d2c2 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -269,20 +269,11 @@ def alter_native_filters(dashboard: Dashboard) -> None: return dashboard_to_import.id # type: ignore -def decode_dashboards( # pylint: disable=too-many-return-statements - o: Dict[str, Any] -) -> Any: +def decode_dashboards(o: Dict[str, Any]) -> Any: """ Function to be passed into json.loads obj_hook parameter Recreates the dashboard object from a json representation. """ - # pylint: disable=import-outside-toplevel - from superset.connectors.druid.models import ( - DruidCluster, - DruidColumn, - DruidDatasource, - DruidMetric, - ) if "__Dashboard__" in o: return Dashboard(**o["__Dashboard__"]) @@ -294,14 +285,6 @@ def decode_dashboards( # pylint: disable=too-many-return-statements return SqlaTable(**o["__SqlaTable__"]) if "__SqlMetric__" in o: return SqlMetric(**o["__SqlMetric__"]) - if "__DruidCluster__" in o: - return DruidCluster(**o["__DruidCluster__"]) - if "__DruidColumn__" in o: - return DruidColumn(**o["__DruidColumn__"]) - if "__DruidDatasource__" in o: - return DruidDatasource(**o["__DruidDatasource__"]) - if "__DruidMetric__" in o: - return DruidMetric(**o["__DruidMetric__"]) if "__datetime__" in o: return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S") diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index 7f13261edd3d..74b9ca0a6abb 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -27,16 +27,11 @@ from superset.commands.base import BaseCommand from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric -from superset.connectors.druid.models import ( - DruidCluster, - DruidColumn, - DruidDatasource, - DruidMetric, -) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.datasets.commands.exceptions import DatasetInvalidError from superset.models.core import Database -from superset.utils.dict_import_export import DATABASES_KEY, DRUID_CLUSTERS_KEY +from superset.utils.dict_import_export import DATABASES_KEY logger = logging.getLogger(__name__) @@ -65,21 +60,6 @@ def lookup_sqla_database(table: SqlaTable) -> Optional[Database]: return database -def lookup_druid_cluster(datasource: DruidDatasource) -> Optional[DruidCluster]: - return db.session.query(DruidCluster).filter_by(id=datasource.cluster_id).first() - - -def lookup_druid_datasource(datasource: DruidDatasource) -> Optional[DruidDatasource]: - return ( - db.session.query(DruidDatasource) - .filter( - DruidDatasource.datasource_name == datasource.datasource_name, - DruidDatasource.cluster_id == datasource.cluster_id, - ) - .first() - ) - - def import_dataset( i_datasource: BaseDatasource, database_id: Optional[int] = None, @@ -97,9 +77,9 @@ def import_dataset( if isinstance(i_datasource, SqlaTable): lookup_database = lookup_sqla_database lookup_datasource = lookup_sqla_table + else: - lookup_database = lookup_druid_cluster - lookup_datasource = lookup_druid_datasource + raise DatasetInvalidError return import_datasource( db.session, @@ -122,22 +102,11 @@ def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric: ) -def lookup_druid_metric(session: Session, metric: DruidMetric) -> DruidMetric: - return ( - session.query(DruidMetric) - .filter( - DruidMetric.datasource_id == metric.datasource_id, - DruidMetric.metric_name == metric.metric_name, - ) - .first() - ) - - def import_metric(session: Session, metric: BaseMetric) -> BaseMetric: if isinstance(metric, SqlMetric): lookup_metric = lookup_sqla_metric else: - lookup_metric = lookup_druid_metric + raise Exception(f"Invalid metric type: {metric}") return import_simple_obj(session, metric, lookup_metric) @@ -152,22 +121,11 @@ def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: ) -def lookup_druid_column(session: Session, column: DruidColumn) -> DruidColumn: - return ( - session.query(DruidColumn) - .filter( - DruidColumn.datasource_id == column.datasource_id, - DruidColumn.column_name == column.column_name, - ) - .first() - ) - - def import_column(session: Session, column: BaseColumn) -> BaseColumn: if isinstance(column, TableColumn): lookup_column = lookup_sqla_column else: - lookup_column = lookup_druid_column + raise Exception(f"Invalid column type: {column}") return import_simple_obj(session, column, lookup_column) @@ -257,19 +215,13 @@ def import_simple_obj( def import_from_dict( session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None ) -> None: - """Imports databases and druid clusters from dictionary""" + """Imports databases from dictionary""" if not sync: sync = [] if isinstance(data, dict): logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) for database in data.get(DATABASES_KEY, []): Database.import_from_dict(session, database, sync=sync) - - logger.info( - "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY - ) - for datasource in data.get(DRUID_CLUSTERS_KEY, []): - DruidCluster.import_from_dict(session, datasource, sync=sync) session.commit() else: logger.info("Supplied object is not a dictionary.") @@ -334,7 +286,7 @@ def validate(self) -> None: # CLI export if isinstance(config, dict): # TODO (betodealmeida): validate with Marshmallow - if DATABASES_KEY not in config and DRUID_CLUSTERS_KEY not in config: + if DATABASES_KEY not in config: raise IncorrectVersionError(f"{file_name} has no valid keys") # UI export diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index dff5e2a7abe7..5f216772168d 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -118,13 +118,6 @@ def init_views(self) -> None: from superset.cachekeys.api import CacheRestApi from superset.charts.api import ChartRestApi from superset.charts.data.api import ChartDataRestApi - from superset.connectors.druid.views import ( - Druid, - DruidClusterModelView, - DruidColumnInlineView, - DruidDatasourceModelView, - DruidMetricInlineView, - ) from superset.connectors.sqla.views import ( RowLevelSecurityFiltersModelView, SqlMetricInlineView, @@ -151,7 +144,7 @@ def init_views(self) -> None: from superset.reports.logs.api import ReportExecutionLogRestApi from superset.security.api import SecurityRestApi from superset.views.access_requests import AccessRequestsModelView - from superset.views.alerts import AlertView, ReportView + from superset.views.alerts import AlertView from superset.views.annotations import ( AnnotationLayerModelView, AnnotationModelView, @@ -405,66 +398,6 @@ def init_views(self) -> None: menu_cond=lambda: bool(self.config["ENABLE_ACCESS_REQUEST"]), ) - # - # Druid Views - # - appbuilder.add_separator( - "Data", cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]) - ) - appbuilder.add_view( - DruidDatasourceModelView, - "Druid Datasources", - label=__("Druid Datasources"), - category="Data", - category_label=__("Data"), - icon="fa-cube", - menu_cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]), - ) - appbuilder.add_view( - DruidClusterModelView, - name="Druid Clusters", - label=__("Druid Clusters"), - icon="fa-cubes", - category="Data", - category_label=__("Data"), - category_icon="fa-database", - menu_cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]), - ) - appbuilder.add_view_no_menu(DruidMetricInlineView) - appbuilder.add_view_no_menu(DruidColumnInlineView) - appbuilder.add_view_no_menu(Druid) - - appbuilder.add_link( - "Scan New Datasources", - label=__("Scan New Datasources"), - href="/druid/scan_new_datasources/", - category="Data", - category_label=__("Data"), - category_icon="fa-database", - icon="fa-refresh", - cond=lambda: bool( - self.config["DRUID_IS_ACTIVE"] - and self.config["DRUID_METADATA_LINKS_ENABLED"] - ), - ) - appbuilder.add_view_no_menu(ReportView) - appbuilder.add_link( - "Refresh Druid Metadata", - label=__("Refresh Druid Metadata"), - href="/druid/refresh_datasources/", - category="Data", - category_label=__("Data"), - category_icon="fa-database", - icon="fa-cog", - cond=lambda: bool( - self.config["DRUID_IS_ACTIVE"] - and self.config["DRUID_METADATA_LINKS_ENABLED"] - ), - ) - appbuilder.add_separator( - "Data", cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]) - ) - def init_app_in_ctx(self) -> None: """ Runs init logic in the context of the app diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index f9dde531d37f..699cd6dd1bbf 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -49,7 +49,6 @@ from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager from superset.common.request_contexed_based import is_user_admin from superset.connectors.base.models import BaseDatasource -from superset.connectors.druid.models import DruidColumn, DruidMetric from superset.connectors.sqla.models import SqlMetric, TableColumn from superset.extensions import cache_manager from superset.models.filter_set import FilterSet @@ -488,8 +487,6 @@ def clear_dashboard_cache( Dashboard.clear_cache_for_datasource(datasource_id=obj.id) elif isinstance(obj, (SqlMetric, TableColumn)): Dashboard.clear_cache_for_datasource(datasource_id=obj.table_id) - elif isinstance(obj, (DruidMetric, DruidColumn)): - Dashboard.clear_cache_for_datasource(datasource_id=obj.datasource_id) sqla.event.listen(Dashboard, "after_update", clear_dashboard_cache) sqla.event.listen( @@ -504,5 +501,3 @@ def clear_dashboard_cache( # trigger update events for BaseDatasource. sqla.event.listen(SqlMetric, "after_update", clear_dashboard_cache) sqla.event.listen(TableColumn, "after_update", clear_dashboard_cache) - sqla.event.listen(DruidMetric, "after_update", clear_dashboard_cache) - sqla.event.listen(DruidColumn, "after_update", clear_dashboard_cache) diff --git a/superset/security/manager.py b/superset/security/manager.py index 48d43d01d0f7..c422c2e8bbe3 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -79,7 +79,6 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource - from superset.connectors.druid.models import DruidCluster from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.sql_lab import Query @@ -153,9 +152,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods GAMMA_READ_ONLY_MODEL_VIEWS = { "Dataset", - "DruidColumnInlineView", - "DruidDatasourceModelView", - "DruidMetricInlineView", "Datasource", } | READ_ONLY_MODEL_VIEWS @@ -325,7 +321,7 @@ def can_access_all_databases(self) -> bool: return self.can_access("all_database_access", "all_database_access") - def can_access_database(self, database: Union["Database", "DruidCluster"]) -> bool: + def can_access_database(self, database: "Database") -> bool: """ Return True if the user can fully access the Superset database, False otherwise. diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 1924d5964294..93070732e78f 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -19,12 +19,10 @@ from sqlalchemy.orm import Session -from superset.connectors.druid.models import DruidCluster from superset.models.core import Database EXPORT_VERSION = "1.0.0" DATABASES_KEY = "databases" -DRUID_CLUSTERS_KEY = "druid_clusters" logger = logging.getLogger(__name__) @@ -33,14 +31,9 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: databases = [ Database.export_schema(recursive=True, include_parent_ref=back_references) ] - clusters = [ - DruidCluster.export_schema(recursive=True, include_parent_ref=back_references) - ] data = {} if databases: data[DATABASES_KEY] = databases - if clusters: - data[DRUID_CLUSTERS_KEY] = clusters return data @@ -59,19 +52,7 @@ def export_to_dict( for database in dbs ] logger.info("Exported %d %s", len(databases), DATABASES_KEY) - cls = session.query(DruidCluster) - clusters = [ - cluster.export_to_dict( - recursive=recursive, - include_parent_ref=back_references, - include_defaults=include_defaults, - ) - for cluster in cls - ] - logger.info("Exported %d %s", len(clusters), DRUID_CLUSTERS_KEY) data = {} if databases: data[DATABASES_KEY] = databases - if clusters: - data[DRUID_CLUSTERS_KEY] = clusters return data diff --git a/superset/views/base.py b/superset/views/base.py index 4695eab39562..17183e59a774 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -20,7 +20,7 @@ import logging import traceback from datetime import datetime -from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, Dict, List, Optional, Union import simplejson as json import yaml @@ -80,16 +80,12 @@ from .utils import bootstrap_user_data -if TYPE_CHECKING: - from superset.connectors.druid.views import DruidClusterModelView - FRONTEND_CONF_KEYS = ( "SUPERSET_WEBSERVER_TIMEOUT", "SUPERSET_DASHBOARD_POSITION_DATA_LIMIT", "SUPERSET_DASHBOARD_PERIODICAL_REFRESH_LIMIT", "SUPERSET_DASHBOARD_PERIODICAL_REFRESH_WARNING_MESSAGE", "DISABLE_DATASET_SOURCE_EDIT", - "DRUID_IS_ACTIVE", "ENABLE_JAVASCRIPT_CONTROLS", "DEFAULT_SQLLAB_LIMIT", "DEFAULT_VIZ_TYPE", diff --git a/superset/views/core.py b/superset/views/core.py index 9f81c071642f..64a547348d6a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1998,61 +1998,6 @@ def dashboard_permalink( # pylint: disable=no-self-use def log(self) -> FlaskResponse: # pylint: disable=no-self-use return Response(status=200) - @has_access - @expose("/sync_druid/", methods=["POST"]) - @event_logger.log_this - def sync_druid_source(self) -> FlaskResponse: # pylint: disable=no-self-use - """Syncs the druid datasource in main db with the provided config. - - The endpoint takes 3 arguments: - user - user name to perform the operation as - cluster - name of the druid cluster - config - configuration stored in json that contains: - name: druid datasource name - dimensions: list of the dimensions, they become druid columns - with the type STRING - metrics_spec: list of metrics (dictionary). Metric consists of - 2 attributes: type and name. Type can be count, - etc. `count` type is stored internally as longSum - other fields will be ignored. - - Example: { - 'name': 'test_click', - 'metrics_spec': [{'type': 'count', 'name': 'count'}], - 'dimensions': ['affiliate_id', 'campaign', 'first_seen'] - } - """ - payload = request.get_json(force=True) - druid_config = payload["config"] - user_name = payload["user"] - cluster_name = payload["cluster"] - - user = security_manager.find_user(username=user_name) - DruidDatasource = ConnectorRegistry.sources[ # pylint: disable=invalid-name - "druid" - ] - DruidCluster = DruidDatasource.cluster_class # pylint: disable=invalid-name - if not user: - err_msg = __("Can't find user, please ask your admin to create one.") - logger.error(err_msg, exc_info=True) - return json_error_response(err_msg) - cluster = ( - db.session.query(DruidCluster) - .filter_by(cluster_name=cluster_name) - .one_or_none() - ) - if not cluster: - err_msg = __("Can't find DruidCluster") - logger.error(err_msg, exc_info=True) - return json_error_response(err_msg) - try: - DruidDatasource.sync_to_db_from_config(druid_config, user, cluster) - except Exception as ex: # pylint: disable=broad-except - err_msg = utils.error_msg_from_exception(ex) - logger.exception(err_msg) - return json_error_response(err_msg) - return Response(status=201) - @has_access @expose("/get_or_create_table/", methods=["POST"]) @event_logger.log_this diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index abefc58c9bc6..d26b07504ced 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -38,7 +38,6 @@ from tests.integration_tests.test_app import app # isort:skip from superset import db, security_manager from superset.connectors.connector_registry import ConnectorRegistry -from superset.connectors.druid.models import DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.models import core as models from superset.models.datasource_access_request import DatasourceAccessRequest @@ -114,8 +113,6 @@ class TestRequestAccess(SupersetTestCase): @classmethod def setUpClass(cls): with app.app_context(): - cls.create_druid_test_objects() - security_manager.add_role("override_me") security_manager.add_role(TEST_ROLE_1) security_manager.add_role(TEST_ROLE_2) @@ -181,40 +178,6 @@ def test_override_role_permissions_1_table(self): "datasource_access", updated_override_me.permissions[0].permission.name ) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_override_role_permissions_druid_and_table(self): - database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - - perm_data = ROLE_ALL_PERM_DATA.copy() - perm_data["database"][0]["schema"][0]["name"] = schema - response = self.client.post( - "/superset/override_role_permissions/", - data=json.dumps(ROLE_ALL_PERM_DATA), - content_type="application/json", - ) - self.assertEqual(201, response.status_code) - - updated_role = security_manager.find_role("override_me") - perms = sorted(updated_role.permissions, key=lambda p: p.view_menu.name) - druid_ds_1 = self.get_druid_ds_by_name("druid_ds_1") - self.assertEqual(druid_ds_1.perm, perms[0].view_menu.name) - self.assertEqual("datasource_access", perms[0].permission.name) - - druid_ds_2 = self.get_druid_ds_by_name("druid_ds_2") - self.assertEqual(druid_ds_2.perm, perms[1].view_menu.name) - self.assertEqual( - "datasource_access", updated_role.permissions[1].permission.name - ) - - birth_names = self.get_table(name="birth_names") - self.assertEqual(birth_names.perm, perms[2].view_menu.name) - self.assertEqual( - "datasource_access", updated_role.permissions[2].permission.name - ) - self.assertEqual(3, len(perms)) - @pytest.mark.usefixtures( "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" ) @@ -596,56 +559,6 @@ def test_request_access(self): "".format(approve_link_3), ) - # Request druid access, there are no roles have this table. - druid_ds_4 = ( - session.query(DruidDatasource) - .filter_by(datasource_name="druid_ds_1") - .first() - ) - druid_ds_4_id = druid_ds_4.id - - # request access to the table - self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_4_id, "go")) - access_request4 = self.get_access_requests("gamma", "druid", druid_ds_4_id) - - self.assertEqual(access_request4.roles_with_datasource, "") - - # Case 5. Roles exist that contains the druid datasource. - # add druid ds to the existing roles - druid_ds_5 = ( - session.query(DruidDatasource) - .filter_by(datasource_name="druid_ds_2") - .first() - ) - druid_ds_5_id = druid_ds_5.id - druid_ds_5_perm = druid_ds_5.perm - - druid_ds_2_role = security_manager.add_role("druid_ds_2_role") - admin_role = security_manager.find_role("Admin") - security_manager.add_permission_role( - admin_role, - security_manager.find_permission_view_menu( - "datasource_access", druid_ds_5_perm - ), - ) - security_manager.add_permission_role( - druid_ds_2_role, - security_manager.find_permission_view_menu( - "datasource_access", druid_ds_5_perm - ), - ) - session.commit() - - self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go")) - access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id) - approve_link_5 = ROLE_GRANT_LINK.format( - "druid", druid_ds_5_id, "gamma", "druid_ds_2_role", "druid_ds_2_role" - ) - self.assertEqual( - access_request5.roles_with_datasource, - "".format(approve_link_5), - ) - # cleanup gamma_user = security_manager.find_user(username="gamma") gamma_user.roles.remove(security_manager.find_role("dummy_role")) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index fcd5e0908854..c2d2ef990e05 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -38,7 +38,6 @@ from superset.sql_parse import CtasMethod from superset import db, security_manager from superset.connectors.base.models import BaseDatasource -from superset.connectors.druid.models import DruidCluster, DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.models import core as models from superset.models.slice import Slice @@ -153,7 +152,7 @@ def create_user_with_roles( user_to_create.roles = [] for chosen_user_role in roles: if should_create_roles: - ## copy role from gamma but without data permissions + # copy role from gamma but without data permissions security_manager.copy_role("Gamma", chosen_user_role, merge=False) user_to_create.roles.append(security_manager.find_role(chosen_user_role)) db.session.commit() @@ -191,30 +190,6 @@ def get_role(name: str) -> Optional[ab_models.User]: ) return user - @classmethod - def create_druid_test_objects(cls): - # create druid cluster and druid datasources - - with app.app_context(): - session = db.session - cluster = ( - session.query(DruidCluster).filter_by(cluster_name="druid_test").first() - ) - if not cluster: - cluster = DruidCluster(cluster_name="druid_test") - session.add(cluster) - session.commit() - - druid_datasource1 = DruidDatasource( - datasource_name="druid_ds_1", cluster=cluster - ) - session.add(druid_datasource1) - druid_datasource2 = DruidDatasource( - datasource_name="druid_ds_2", cluster=cluster - ) - session.add(druid_datasource2) - session.commit() - @staticmethod def get_table_by_id(table_id: int) -> SqlaTable: return db.session.query(SqlaTable).filter_by(id=table_id).one() @@ -275,10 +250,6 @@ def get_database_by_name(database_name: str = "main") -> Database: else: raise ValueError("Database doesn't exist") - @staticmethod - def get_druid_ds_by_name(name: str) -> DruidDatasource: - return db.session.query(DruidDatasource).filter_by(datasource_name=name).first() - @staticmethod def get_datasource_mock() -> BaseDatasource: datasource = Mock() diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 38e9b4ab0e01..df1e01c5a89c 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -294,7 +294,6 @@ def test_admin_only_permissions(self): def assert_admin_permission_in(role_name, assert_func): role = security_manager.find_role(role_name) permissions = [p.permission.name for p in role.permissions] - assert_func("can_sync_druid_source", permissions) assert_func("can_approve", permissions) assert_admin_permission_in("Admin", self.assertIn) diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index bd9f79e4aed4..de0aa832626a 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -24,12 +24,7 @@ from tests.integration_tests.test_app import app from superset import db -from superset.connectors.druid.models import ( - DruidColumn, - DruidDatasource, - DruidMetric, - DruidCluster, -) + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.utils.database import get_example_database from superset.utils.dict_import_export import export_to_dict @@ -52,9 +47,6 @@ def delete_imports(cls): for table in session.query(SqlaTable): if DBREF in table.params_dict: session.delete(table) - for datasource in session.query(DruidDatasource): - if DBREF in datasource.params_dict: - session.delete(datasource) session.commit() @classmethod @@ -96,38 +88,6 @@ def create_table( table.metrics.append(SqlMetric(metric_name=metric_name, expression="")) return table, dict_rep - def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): - cluster_name = "druid_test" - cluster = self.get_or_create( - DruidCluster, {"cluster_name": cluster_name}, db.session - ) - - name = "{0}{1}".format(NAME_PREFIX, name) - params = {DBREF: id, "database_name": cluster_name} - dict_rep = { - "cluster_id": cluster.id, - "datasource_name": name, - "id": id, - "params": json.dumps(params), - "columns": [{"column_name": c} for c in cols_names], - "metrics": [{"metric_name": c, "json": "{}"} for c in metric_names], - } - - datasource = DruidDatasource( - id=id, - datasource_name=name, - cluster_id=cluster.id, - params=json.dumps(params), - ) - for col_name in cols_names: - datasource.columns.append(DruidColumn(column_name=col_name)) - for metric_name in metric_names: - datasource.metrics.append(DruidMetric(metric_name=metric_name)) - return datasource, dict_rep - - def get_datasource(self, datasource_id): - return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() - def yaml_compare(self, obj_1, obj_2): obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False) obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False) @@ -308,118 +268,6 @@ def test_export_datasource_ui_cli(self): ui_export["databases"][0]["tables"], cli_export["databases"][0]["tables"] ) - def test_import_druid_no_metadata(self): - datasource, dict_datasource = self.create_druid_datasource( - "pure_druid", id=ID_PREFIX + 1 - ) - imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) - db.session.commit() - imported = self.get_datasource(imported_cluster.id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_1_col_1_met(self): - datasource, dict_datasource = self.create_druid_datasource( - "druid_1_col_1_met", - id=ID_PREFIX + 2, - cols_names=["col1"], - metric_names=["metric1"], - ) - imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) - db.session.commit() - imported = self.get_datasource(imported_cluster.id) - self.assert_datasource_equals(datasource, imported) - self.assertEqual( - {DBREF: ID_PREFIX + 2, "database_name": "druid_test"}, - json.loads(imported.params), - ) - - def test_import_druid_2_col_2_met(self): - datasource, dict_datasource = self.create_druid_datasource( - "druid_2_col_2_met", - id=ID_PREFIX + 3, - cols_names=["c1", "c2"], - metric_names=["m1", "m2"], - ) - imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) - db.session.commit() - imported = self.get_datasource(imported_cluster.id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_override_append(self): - datasource, dict_datasource = self.create_druid_datasource( - "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] - ) - imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) - db.session.commit() - table_over, table_over_dict = self.create_druid_datasource( - "druid_override", - id=ID_PREFIX + 3, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_over_cluster = DruidDatasource.import_from_dict( - db.session, table_over_dict - ) - db.session.commit() - imported_over = self.get_datasource(imported_over_cluster.id) - self.assertEqual(imported_cluster.id, imported_over.id) - expected_datasource, _ = self.create_druid_datasource( - "druid_override", - id=ID_PREFIX + 3, - metric_names=["new_metric1", "m1"], - cols_names=["col1", "new_col1", "col2", "col3"], - ) - self.assert_datasource_equals(expected_datasource, imported_over) - - def test_import_druid_override_sync(self): - datasource, dict_datasource = self.create_druid_datasource( - "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] - ) - imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource) - db.session.commit() - table_over, table_over_dict = self.create_druid_datasource( - "druid_override", - id=ID_PREFIX + 3, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_over_cluster = DruidDatasource.import_from_dict( - session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"] - ) # syncing metrics and columns - db.session.commit() - imported_over = self.get_datasource(imported_over_cluster.id) - self.assertEqual(imported_cluster.id, imported_over.id) - expected_datasource, _ = self.create_druid_datasource( - "druid_override", - id=ID_PREFIX + 3, - metric_names=["new_metric1"], - cols_names=["new_col1", "col2", "col3"], - ) - self.assert_datasource_equals(expected_datasource, imported_over) - - def test_import_druid_override_identical(self): - datasource, dict_datasource = self.create_druid_datasource( - "copy_cat", - id=ID_PREFIX + 4, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported = DruidDatasource.import_from_dict( - session=db.session, dict_rep=dict_datasource - ) - db.session.commit() - copy_datasource, dict_cp_datasource = self.create_druid_datasource( - "copy_cat", - id=ID_PREFIX + 4, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource) - db.session.commit() - - self.assertEqual(imported.id, imported_copy.id) - self.assert_datasource_equals(copy_datasource, self.get_datasource(imported.id)) - if __name__ == "__main__": unittest.main() diff --git a/tests/integration_tests/druid_func_tests.py b/tests/integration_tests/druid_func_tests.py deleted file mode 100644 index 7227f485b219..000000000000 --- a/tests/integration_tests/druid_func_tests.py +++ /dev/null @@ -1,1152 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# isort:skip_file -import json -import unittest -from unittest.mock import Mock - -import tests.integration_tests.test_app -import superset.connectors.druid.models as models -from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric -from superset.exceptions import SupersetException - -from .base_tests import SupersetTestCase - -try: - from pydruid.utils.dimensions import ( - MapLookupExtraction, - RegexExtraction, - RegisteredLookupExtraction, - TimeFormatExtraction, - ) - import pydruid.utils.postaggregator as postaggs -except ImportError: - pass - - -def mock_metric(metric_name, is_postagg=False): - metric = Mock() - metric.metric_name = metric_name - metric.metric_type = "postagg" if is_postagg else "metric" - return metric - - -def emplace(metrics_dict, metric_name, is_postagg=False): - metrics_dict[metric_name] = mock_metric(metric_name, is_postagg) - - -# Unit tests that can be run without initializing base tests -class TestDruidFunc(SupersetTestCase): - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_map(self): - filters = [{"col": "deviceName", "val": ["iPhone X"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "device", - "outputName": "deviceName", - "outputType": "STRING", - "extractionFn": { - "type": "lookup", - "dimension": "dimensionName", - "outputName": "dimensionOutputName", - "replaceMissingValueWith": "missing_value", - "retainMissingValue": False, - "lookup": { - "type": "map", - "map": { - "iPhone10,1": "iPhone 8", - "iPhone10,4": "iPhone 8", - "iPhone10,2": "iPhone 8 Plus", - "iPhone10,5": "iPhone 8 Plus", - "iPhone10,3": "iPhone X", - "iPhone10,6": "iPhone X", - }, - "isOneToOne": False, - }, - }, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="deviceName", dimension_spec_json=spec_json) - column_dict = {"deviceName": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, MapLookupExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - f_ext_fn = f.extraction_function - self.assertEqual(dim_ext_fn["lookup"]["map"], f_ext_fn._mapping) - self.assertEqual(dim_ext_fn["lookup"]["isOneToOne"], f_ext_fn._injective) - self.assertEqual( - dim_ext_fn["replaceMissingValueWith"], f_ext_fn._replace_missing_values - ) - self.assertEqual( - dim_ext_fn["retainMissingValue"], f_ext_fn._retain_missing_values - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_regex(self): - filters = [{"col": "buildPrefix", "val": ["22B"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "build", - "outputName": "buildPrefix", - "outputType": "STRING", - "extractionFn": {"type": "regex", "expr": "(^[0-9A-Za-z]{3})"}, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="buildPrefix", dimension_spec_json=spec_json) - column_dict = {"buildPrefix": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, RegexExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - f_ext_fn = f.extraction_function - self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_registered_lookup_extraction(self): - filters = [{"col": "country", "val": ["Spain"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "country_name", - "outputName": "country", - "outputType": "STRING", - "extractionFn": {"type": "registeredLookup", "lookup": "country_name"}, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="country", dimension_spec_json=spec_json) - column_dict = {"country": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, RegisteredLookupExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) - self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_time_format(self): - filters = [{"col": "dayOfMonth", "val": ["1", "20"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "__time", - "outputName": "dayOfMonth", - "extractionFn": { - "type": "timeFormat", - "format": "d", - "timeZone": "Asia/Kolkata", - "locale": "en", - }, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="dayOfMonth", dimension_spec_json=spec_json) - column_dict = {"dayOfMonth": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, TimeFormatExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) - self.assertEqual(dim_ext_fn["format"], f.extraction_function._format) - self.assertEqual(dim_ext_fn["timeZone"], f.extraction_function._time_zone) - self.assertEqual(dim_ext_fn["locale"], f.extraction_function._locale) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_ignores_invalid_filter_objects(self): - filtr = {"col": "col1", "op": "=="} - filters = [filtr] - col = DruidColumn(column_name="col1") - column_dict = {"col1": col} - self.assertIsNone(DruidDatasource.get_filters(filters, [], column_dict)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_in(self): - filtr = {"col": "A", "op": "in", "val": ["a", "b", "c"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIn("filter", res.filter) - self.assertIn("fields", res.filter["filter"]) - self.assertEqual("or", res.filter["filter"]["type"]) - self.assertEqual(3, len(res.filter["filter"]["fields"])) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_not_in(self): - filtr = {"col": "A", "op": "not in", "val": ["a", "b", "c"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIn("filter", res.filter) - self.assertIn("type", res.filter["filter"]) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertIn("field", res.filter["filter"]) - self.assertEqual( - 3, len(res.filter["filter"]["field"].filter["filter"]["fields"]) - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_equals(self): - filtr = {"col": "A", "op": "==", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - self.assertEqual("h", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_not_equals(self): - filtr = {"col": "A", "op": "!=", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertEqual("h", res.filter["filter"]["field"].filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_bounds_filter(self): - filtr = {"col": "A", "op": ">=", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertFalse(res.filter["filter"]["lowerStrict"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - self.assertEqual("h", res.filter["filter"]["lower"]) - self.assertEqual("lexicographic", res.filter["filter"]["ordering"]) - filtr["op"] = ">" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertTrue(res.filter["filter"]["lowerStrict"]) - filtr["op"] = "<=" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertFalse(res.filter["filter"]["upperStrict"]) - self.assertEqual("h", res.filter["filter"]["upper"]) - filtr["op"] = "<" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertTrue(res.filter["filter"]["upperStrict"]) - filtr["val"] = 1 - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual("numeric", res.filter["filter"]["ordering"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_is_null_filter(self): - filtr = {"col": "A", "op": "IS NULL"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - self.assertEqual("", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_is_not_null_filter(self): - filtr = {"col": "A", "op": "IS NOT NULL"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertIn("field", res.filter["filter"]) - self.assertEqual( - "selector", res.filter["filter"]["field"].filter["filter"]["type"] - ) - self.assertEqual("", res.filter["filter"]["field"].filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_regex_filter(self): - filtr = {"col": "A", "op": "regex", "val": "[abc]"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("regex", res.filter["filter"]["type"]) - self.assertEqual("[abc]", res.filter["filter"]["pattern"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_composes_multiple_filters(self): - filtr1 = {"col": "A", "op": "!=", "val": "y"} - filtr2 = {"col": "B", "op": "in", "val": ["a", "b", "c"]} - cola = DruidColumn(column_name="A") - colb = DruidColumn(column_name="B") - column_dict = {"A": cola, "B": colb} - res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) - self.assertEqual("and", res.filter["filter"]["type"]) - self.assertEqual(2, len(res.filter["filter"]["fields"])) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_ignores_in_not_in_with_empty_value(self): - filtr1 = {"col": "A", "op": "in", "val": []} - filtr2 = {"col": "A", "op": "not in", "val": []} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) - self.assertIsNone(res) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_equals_for_in_not_in_single_value(self): - filtr = {"col": "A", "op": "in", "val": ["a"]} - cola = DruidColumn(column_name="A") - colb = DruidColumn(column_name="B") - column_dict = {"A": cola, "B": colb} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_handles_arrays_for_string_types(self): - filtr = {"col": "A", "op": "==", "val": ["a", "b"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("a", res.filter["filter"]["value"]) - - filtr = {"col": "A", "op": "==", "val": []} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIsNone(res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_handles_none_for_string_types(self): - filtr = {"col": "A", "op": "==", "val": None} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIsNone(res) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extracts_values_in_quotes(self): - filtr = {"col": "A", "op": "in", "val": ['"a"']} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual('"a"', res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_keeps_trailing_spaces(self): - filtr = {"col": "A", "op": "in", "val": ["a "]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("a ", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_converts_strings_to_num(self): - filtr = {"col": "A", "op": "in", "val": ["6"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual(6, res.filter["filter"]["value"]) - filtr = {"col": "A", "op": "==", "val": "6"} - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual(6, res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_no_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = [] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(1, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.timeseries.call_args_list[0][1] - self.assertNotIn("dimensions", called_args) - self.assertIn("post_aggregations", called_args) - # restore functions - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_with_adhoc_metric(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - all_metrics = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) - groupby = [] - metrics = [ - { - "expressionType": "SIMPLE", - "column": {"type": "DOUBLE", "column_name": "col1"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - ] - - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(1, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.timeseries.call_args_list[0][1] - self.assertNotIn("dimensions", called_args) - self.assertIn("post_aggregations", called_args) - # restore functions - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_single_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = ["metric1"] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = ["col1"] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder.last_query.query_dict = {"mock": 0} - # client.topn is called twice - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - timeseries_limit=100, - client=client, - order_desc=True, - filter=[], - ) - self.assertEqual(2, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args_pre = client.topn.call_args_list[0][1] - self.assertNotIn("dimensions", called_args_pre) - self.assertIn("dimension", called_args_pre) - called_args = client.topn.call_args_list[1][1] - self.assertIn("dimension", called_args) - self.assertEqual("col1", called_args["dimension"]) - # not order_desc - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - order_desc=False, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(1, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - self.assertIn("dimensions", client.groupby.call_args_list[0][1]) - self.assertEqual(["col1"], client.groupby.call_args_list[0][1]["dimensions"]) - # order_desc but timeseries and dimension spec - # calls topn with single dimension spec 'dimension' - spec = {"outputName": "hello", "dimension": "matcho"} - spec_json = json.dumps(spec) - col3 = DruidColumn(column_name="col3", dimension_spec_json=spec_json) - ds.columns.append(col3) - groupby = ["col3"] - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - order_desc=True, - timeseries_limit=5, - filter=[], - row_limit=100, - ) - self.assertEqual(2, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - self.assertIn("dimension", client.topn.call_args_list[0][1]) - self.assertIn("dimension", client.topn.call_args_list[1][1]) - # uses dimension for pre query and full spec for final query - self.assertEqual("matcho", client.topn.call_args_list[0][1]["dimension"]) - self.assertEqual(spec, client.topn.call_args_list[1][1]["dimension"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_multiple_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - groupby = ["col1", "col2"] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - row_limit=100, - filter=[], - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(1, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.groupby.call_args_list[0][1] - self.assertIn("dimensions", called_args) - self.assertEqual(["col1", "col2"], called_args["dimensions"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_post_agg_returns_correct_agg_type(self): - get_post_agg = DruidDatasource.get_post_agg - # javascript PostAggregators - function = "function(field1, field2) { return field1 + field2; }" - conf = { - "type": "javascript", - "name": "postagg_name", - "fieldNames": ["field1", "field2"], - "function": function, - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["type"], "javascript") - self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"]) - self.assertEqual(postagg.post_aggregator["name"], "postagg_name") - self.assertEqual(postagg.post_aggregator["function"], function) - # Quantile - conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Quantile)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["probability"], "0.5") - # Quantiles - conf = { - "type": "quantiles", - "name": "postagg_name", - "probabilities": "0.4,0.5,0.6", - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Quantiles)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["probabilities"], "0.4,0.5,0.6") - # FieldAccess - conf = {"type": "fieldAccess", "name": "field_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Field)) - self.assertEqual(postagg.name, "field_name") - # constant - conf = {"type": "constant", "value": 1234, "name": "postagg_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Const)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["value"], 1234) - # hyperUniqueCardinality - conf = {"type": "hyperUniqueCardinality", "name": "unique_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality)) - self.assertEqual(postagg.name, "unique_name") - # arithmetic - conf = { - "type": "arithmetic", - "fn": "+", - "fields": ["field1", "field2"], - "name": "postagg_name", - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Postaggregator)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["fn"], "+") - self.assertEqual(postagg.post_aggregator["fields"], ["field1", "field2"]) - # custom post aggregator - conf = {"type": "custom", "name": "custom_name", "stuff": "more_stuff"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, models.CustomPostAggregator)) - self.assertEqual(postagg.name, "custom_name") - self.assertEqual(postagg.post_aggregator["stuff"], "more_stuff") - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_find_postaggs_for_returns_postaggs_and_removes(self): - find_postaggs_for = DruidDatasource.find_postaggs_for - postagg_names = set(["pa2", "pa3", "pa4", "m1", "m2", "m3", "m4"]) - - metrics = {} - for i in range(1, 6): - emplace(metrics, "pa" + str(i), True) - emplace(metrics, "m" + str(i), False) - postagg_list = find_postaggs_for(postagg_names, metrics) - self.assertEqual(3, len(postagg_list)) - self.assertEqual(4, len(postagg_names)) - expected_metrics = ["m1", "m2", "m3", "m4"] - expected_postaggs = set(["pa2", "pa3", "pa4"]) - for postagg in postagg_list: - expected_postaggs.remove(postagg.metric_name) - for metric in expected_metrics: - postagg_names.remove(metric) - self.assertEqual(0, len(expected_postaggs)) - self.assertEqual(0, len(postagg_names)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_recursive_get_fields(self): - conf = { - "type": "quantile", - "fieldName": "f1", - "field": { - "type": "custom", - "fields": [ - {"type": "fieldAccess", "fieldName": "f2"}, - {"type": "fieldAccess", "fieldName": "f3"}, - { - "type": "quantiles", - "fieldName": "f4", - "field": {"type": "custom"}, - }, - { - "type": "custom", - "fields": [ - {"type": "fieldAccess", "fieldName": "f5"}, - { - "type": "fieldAccess", - "fieldName": "f2", - "fields": [ - {"type": "fieldAccess", "fieldName": "f3"}, - {"type": "fieldIgnoreMe", "fieldName": "f6"}, - ], - }, - ], - }, - ], - }, - } - fields = DruidDatasource.recursive_get_fields(conf) - expected = set(["f1", "f2", "f3", "f4", "f5"]) - self.assertEqual(5, len(fields)) - for field in fields: - expected.remove(field) - self.assertEqual(0, len(expected)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_metrics_and_post_aggs_tree(self): - metrics = ["A", "B", "m1", "m2"] - metrics_dict = {} - for i in range(ord("A"), ord("K") + 1): - emplace(metrics_dict, chr(i), True) - for i in range(1, 10): - emplace(metrics_dict, "m" + str(i), False) - - def depends_on(index, fields): - dependents = fields if isinstance(fields, list) else [fields] - metrics_dict[index].json_obj = {"fieldNames": dependents} - - depends_on("A", ["m1", "D", "C"]) - depends_on("B", ["B", "C", "E", "F", "m3"]) - depends_on("C", ["H", "I"]) - depends_on("D", ["m2", "m5", "G", "C"]) - depends_on("E", ["H", "I", "J"]) - depends_on("F", ["J", "m5"]) - depends_on("G", ["m4", "m7", "m6", "A"]) - depends_on("H", ["A", "m4", "I"]) - depends_on("I", ["H", "K"]) - depends_on("J", "K") - depends_on("K", ["m8", "m9"]) - aggs, postaggs = DruidDatasource.metrics_and_post_aggs(metrics, metrics_dict) - expected_metrics = set(aggs.keys()) - self.assertEqual(9, len(aggs)) - for i in range(1, 10): - expected_metrics.remove("m" + str(i)) - self.assertEqual(0, len(expected_metrics)) - self.assertEqual(11, len(postaggs)) - for i in range(ord("A"), ord("K") + 1): - del postaggs[chr(i)] - self.assertEqual(0, len(postaggs)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_metrics_and_post_aggs(self): - """ - Test generation of metrics and post-aggregations from an initial list - of superset metrics (which may include the results of either). This - primarily tests that specifying a post-aggregator metric will also - require the raw aggregation of the associated druid metric column. - """ - metrics_dict = { - "unused_count": DruidMetric( - metric_name="unused_count", - verbose_name="COUNT(*)", - metric_type="count", - json=json.dumps({"type": "count", "name": "unused_count"}), - ), - "some_sum": DruidMetric( - metric_name="some_sum", - verbose_name="SUM(*)", - metric_type="sum", - json=json.dumps({"type": "sum", "name": "sum"}), - ), - "a_histogram": DruidMetric( - metric_name="a_histogram", - verbose_name="APPROXIMATE_HISTOGRAM(*)", - metric_type="approxHistogramFold", - json=json.dumps({"type": "approxHistogramFold", "name": "a_histogram"}), - ), - "aCustomMetric": DruidMetric( - metric_name="aCustomMetric", - verbose_name="MY_AWESOME_METRIC(*)", - metric_type="aCustomType", - json=json.dumps({"type": "customMetric", "name": "aCustomMetric"}), - ), - "quantile_p95": DruidMetric( - metric_name="quantile_p95", - verbose_name="P95(*)", - metric_type="postagg", - json=json.dumps( - { - "type": "quantile", - "probability": 0.95, - "name": "p95", - "fieldName": "a_histogram", - } - ), - ), - "aCustomPostAgg": DruidMetric( - metric_name="aCustomPostAgg", - verbose_name="CUSTOM_POST_AGG(*)", - metric_type="postagg", - json=json.dumps( - { - "type": "customPostAgg", - "name": "aCustomPostAgg", - "field": {"type": "fieldAccess", "fieldName": "aCustomMetric"}, - } - ), - ), - } - - adhoc_metric = { - "expressionType": "SIMPLE", - "column": {"type": "DOUBLE", "column_name": "value"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - - metrics = ["some_sum"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == {"some_sum"} - assert post_aggs == {} - - metrics = [adhoc_metric] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == set([adhoc_metric["label"]]) - assert post_aggs == {} - - metrics = ["some_sum", adhoc_metric] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == {"some_sum", adhoc_metric["label"]} - assert post_aggs == {} - - metrics = ["quantile_p95"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - result_postaggs = set(["quantile_p95"]) - assert set(saved_metrics.keys()) == {"a_histogram"} - assert set(post_aggs.keys()) == result_postaggs - - metrics = ["aCustomPostAgg"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - result_postaggs = set(["aCustomPostAgg"]) - assert set(saved_metrics.keys()) == {"aCustomMetric"} - assert set(post_aggs.keys()) == result_postaggs - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_druid_type_from_adhoc_metric(self): - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "DOUBLE", "column_name": "value"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "doubleSum" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "LONG", "column_name": "value"}, - "aggregate": "MAX", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "longMax" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "VARCHAR(255)", "column_name": "value"}, - "aggregate": "COUNT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "count" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "VARCHAR(255)", "column_name": "value"}, - "aggregate": "COUNT_DISTINCT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "cardinality" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "hyperUnique", "column_name": "value"}, - "aggregate": "COUNT_DISTINCT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "hyperUnique" - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_order_by_metrics(self): - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - from_dttm = Mock() - to_dttm = Mock() - ds = DruidDatasource(datasource_name="datasource") - ds.get_having_filters = Mock(return_value=[]) - dim1 = DruidColumn(column_name="dim1") - dim2 = DruidColumn(column_name="dim2") - metrics_dict = { - "count1": DruidMetric( - metric_name="count1", - metric_type="count", - json=json.dumps({"type": "count", "name": "count1"}), - ), - "sum1": DruidMetric( - metric_name="sum1", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum1"}), - ), - "sum2": DruidMetric( - metric_name="sum2", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum2"}), - ), - "div1": DruidMetric( - metric_name="div1", - metric_type="postagg", - json=json.dumps( - { - "fn": "/", - "type": "arithmetic", - "name": "div1", - "fields": [ - {"fieldName": "sum1", "type": "fieldAccess"}, - {"fieldName": "sum2", "type": "fieldAccess"}, - ], - } - ), - ), - } - ds.columns = [dim1, dim2] - ds.metrics = list(metrics_dict.values()) - - groupby = ["dim1"] - metrics = ["count1"] - granularity = "all" - # get the counts of the top 5 'dim1's, order by 'sum1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=groupby, - timeseries_limit=5, - timeseries_limit_metric="sum1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.topn.call_args_list[0][1] - self.assertEqual("dim1", qry_obj["dimension"]) - self.assertEqual("sum1", qry_obj["metric"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) - self.assertEqual(set(), set(post_aggregations.keys())) - - # get the counts of the top 5 'dim1's, order by 'div1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=groupby, - timeseries_limit=5, - timeseries_limit_metric="div1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.topn.call_args_list[1][1] - self.assertEqual("dim1", qry_obj["dimension"]) - self.assertEqual("div1", qry_obj["metric"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) - self.assertEqual({"div1"}, set(post_aggregations.keys())) - - groupby = ["dim1", "dim2"] - # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=groupby, - timeseries_limit=5, - timeseries_limit_metric="sum1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.groupby.call_args_list[0][1] - self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) - self.assertEqual("sum1", qry_obj["limit_spec"]["columns"][0]["dimension"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) - self.assertEqual(set(), set(post_aggregations.keys())) - - # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=groupby, - timeseries_limit=5, - timeseries_limit_metric="div1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.groupby.call_args_list[1][1] - self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) - self.assertEqual("div1", qry_obj["limit_spec"]["columns"][0]["dimension"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) - self.assertEqual({"div1"}, set(post_aggregations.keys())) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_aggregations(self): - ds = DruidDatasource(datasource_name="datasource") - metrics_dict = { - "sum1": DruidMetric( - metric_name="sum1", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum1"}), - ), - "sum2": DruidMetric( - metric_name="sum2", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum2"}), - ), - "div1": DruidMetric( - metric_name="div1", - metric_type="postagg", - json=json.dumps( - { - "fn": "/", - "type": "arithmetic", - "name": "div1", - "fields": [ - {"fieldName": "sum1", "type": "fieldAccess"}, - {"fieldName": "sum2", "type": "fieldAccess"}, - ], - } - ), - ), - } - metric_names = ["sum1", "sum2"] - aggs = ds.get_aggregations(metrics_dict, metric_names) - expected_agg = {name: metrics_dict[name].json_obj for name in metric_names} - self.assertEqual(expected_agg, aggs) - - metric_names = ["sum1", "col1"] - self.assertRaises( - SupersetException, ds.get_aggregations, metrics_dict, metric_names - ) - - metric_names = ["sum1", "div1"] - self.assertRaises( - SupersetException, ds.get_aggregations, metrics_dict, metric_names - ) diff --git a/tests/integration_tests/druid_func_tests_sip38.py b/tests/integration_tests/druid_func_tests_sip38.py deleted file mode 100644 index adc355ed83a6..000000000000 --- a/tests/integration_tests/druid_func_tests_sip38.py +++ /dev/null @@ -1,1157 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# isort:skip_file -import json -import unittest -from unittest.mock import Mock, patch - -import tests.integration_tests.test_app -import superset.connectors.druid.models as models -from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric -from superset.exceptions import SupersetException - -from .base_tests import SupersetTestCase - -try: - from pydruid.utils.dimensions import ( - MapLookupExtraction, - RegexExtraction, - RegisteredLookupExtraction, - TimeFormatExtraction, - ) - import pydruid.utils.postaggregator as postaggs -except ImportError: - pass - - -def mock_metric(metric_name, is_postagg=False): - metric = Mock() - metric.metric_name = metric_name - metric.metric_type = "postagg" if is_postagg else "metric" - return metric - - -def emplace(metrics_dict, metric_name, is_postagg=False): - metrics_dict[metric_name] = mock_metric(metric_name, is_postagg) - - -# Unit tests that can be run without initializing base tests -@patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", - {"SIP_38_VIZ_REARCHITECTURE": True}, - clear=True, -) -class TestDruidFunc(SupersetTestCase): - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_map(self): - filters = [{"col": "deviceName", "val": ["iPhone X"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "device", - "outputName": "deviceName", - "outputType": "STRING", - "extractionFn": { - "type": "lookup", - "dimension": "dimensionName", - "outputName": "dimensionOutputName", - "replaceMissingValueWith": "missing_value", - "retainMissingValue": False, - "lookup": { - "type": "map", - "map": { - "iPhone10,1": "iPhone 8", - "iPhone10,4": "iPhone 8", - "iPhone10,2": "iPhone 8 Plus", - "iPhone10,5": "iPhone 8 Plus", - "iPhone10,3": "iPhone X", - "iPhone10,6": "iPhone X", - }, - "isOneToOne": False, - }, - }, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="deviceName", dimension_spec_json=spec_json) - column_dict = {"deviceName": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, MapLookupExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - f_ext_fn = f.extraction_function - self.assertEqual(dim_ext_fn["lookup"]["map"], f_ext_fn._mapping) - self.assertEqual(dim_ext_fn["lookup"]["isOneToOne"], f_ext_fn._injective) - self.assertEqual( - dim_ext_fn["replaceMissingValueWith"], f_ext_fn._replace_missing_values - ) - self.assertEqual( - dim_ext_fn["retainMissingValue"], f_ext_fn._retain_missing_values - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_regex(self): - filters = [{"col": "buildPrefix", "val": ["22B"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "build", - "outputName": "buildPrefix", - "outputType": "STRING", - "extractionFn": {"type": "regex", "expr": "(^[0-9A-Za-z]{3})"}, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="buildPrefix", dimension_spec_json=spec_json) - column_dict = {"buildPrefix": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, RegexExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - f_ext_fn = f.extraction_function - self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_registered_lookup_extraction(self): - filters = [{"col": "country", "val": ["Spain"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "country_name", - "outputName": "country", - "outputType": "STRING", - "extractionFn": {"type": "registeredLookup", "lookup": "country_name"}, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="country", dimension_spec_json=spec_json) - column_dict = {"country": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, RegisteredLookupExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) - self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extraction_fn_time_format(self): - filters = [{"col": "dayOfMonth", "val": ["1", "20"], "op": "in"}] - dimension_spec = { - "type": "extraction", - "dimension": "__time", - "outputName": "dayOfMonth", - "extractionFn": { - "type": "timeFormat", - "format": "d", - "timeZone": "Asia/Kolkata", - "locale": "en", - }, - } - spec_json = json.dumps(dimension_spec) - col = DruidColumn(column_name="dayOfMonth", dimension_spec_json=spec_json) - column_dict = {"dayOfMonth": col} - f = DruidDatasource.get_filters(filters, [], column_dict) - assert isinstance(f.extraction_function, TimeFormatExtraction) - dim_ext_fn = dimension_spec["extractionFn"] - self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type) - self.assertEqual(dim_ext_fn["format"], f.extraction_function._format) - self.assertEqual(dim_ext_fn["timeZone"], f.extraction_function._time_zone) - self.assertEqual(dim_ext_fn["locale"], f.extraction_function._locale) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_ignores_invalid_filter_objects(self): - filtr = {"col": "col1", "op": "=="} - filters = [filtr] - col = DruidColumn(column_name="col1") - column_dict = {"col1": col} - self.assertIsNone(DruidDatasource.get_filters(filters, [], column_dict)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_in(self): - filtr = {"col": "A", "op": "in", "val": ["a", "b", "c"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIn("filter", res.filter) - self.assertIn("fields", res.filter["filter"]) - self.assertEqual("or", res.filter["filter"]["type"]) - self.assertEqual(3, len(res.filter["filter"]["fields"])) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_not_in(self): - filtr = {"col": "A", "op": "not in", "val": ["a", "b", "c"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIn("filter", res.filter) - self.assertIn("type", res.filter["filter"]) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertIn("field", res.filter["filter"]) - self.assertEqual( - 3, len(res.filter["filter"]["field"].filter["filter"]["fields"]) - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_equals(self): - filtr = {"col": "A", "op": "==", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - self.assertEqual("h", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_filter_not_equals(self): - filtr = {"col": "A", "op": "!=", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertEqual("h", res.filter["filter"]["field"].filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_bounds_filter(self): - filtr = {"col": "A", "op": ">=", "val": "h"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertFalse(res.filter["filter"]["lowerStrict"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - self.assertEqual("h", res.filter["filter"]["lower"]) - self.assertEqual("lexicographic", res.filter["filter"]["ordering"]) - filtr["op"] = ">" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertTrue(res.filter["filter"]["lowerStrict"]) - filtr["op"] = "<=" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertFalse(res.filter["filter"]["upperStrict"]) - self.assertEqual("h", res.filter["filter"]["upper"]) - filtr["op"] = "<" - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertTrue(res.filter["filter"]["upperStrict"]) - filtr["val"] = 1 - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual("numeric", res.filter["filter"]["ordering"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_is_null_filter(self): - filtr = {"col": "A", "op": "IS NULL"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - self.assertEqual("", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_is_not_null_filter(self): - filtr = {"col": "A", "op": "IS NOT NULL"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("not", res.filter["filter"]["type"]) - self.assertIn("field", res.filter["filter"]) - self.assertEqual( - "selector", res.filter["filter"]["field"].filter["filter"]["type"] - ) - self.assertEqual("", res.filter["filter"]["field"].filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_regex_filter(self): - filtr = {"col": "A", "op": "regex", "val": "[abc]"} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("regex", res.filter["filter"]["type"]) - self.assertEqual("[abc]", res.filter["filter"]["pattern"]) - self.assertEqual("A", res.filter["filter"]["dimension"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_composes_multiple_filters(self): - filtr1 = {"col": "A", "op": "!=", "val": "y"} - filtr2 = {"col": "B", "op": "in", "val": ["a", "b", "c"]} - cola = DruidColumn(column_name="A") - colb = DruidColumn(column_name="B") - column_dict = {"A": cola, "B": colb} - res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) - self.assertEqual("and", res.filter["filter"]["type"]) - self.assertEqual(2, len(res.filter["filter"]["fields"])) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_ignores_in_not_in_with_empty_value(self): - filtr1 = {"col": "A", "op": "in", "val": []} - filtr2 = {"col": "A", "op": "not in", "val": []} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict) - self.assertIsNone(res) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_constructs_equals_for_in_not_in_single_value(self): - filtr = {"col": "A", "op": "in", "val": ["a"]} - cola = DruidColumn(column_name="A") - colb = DruidColumn(column_name="B") - column_dict = {"A": cola, "B": colb} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("selector", res.filter["filter"]["type"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_handles_arrays_for_string_types(self): - filtr = {"col": "A", "op": "==", "val": ["a", "b"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("a", res.filter["filter"]["value"]) - - filtr = {"col": "A", "op": "==", "val": []} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIsNone(res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_handles_none_for_string_types(self): - filtr = {"col": "A", "op": "==", "val": None} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertIsNone(res) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_extracts_values_in_quotes(self): - filtr = {"col": "A", "op": "in", "val": ['"a"']} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual('"a"', res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_keeps_trailing_spaces(self): - filtr = {"col": "A", "op": "in", "val": ["a "]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], [], column_dict) - self.assertEqual("a ", res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_filters_converts_strings_to_num(self): - filtr = {"col": "A", "op": "in", "val": ["6"]} - col = DruidColumn(column_name="A") - column_dict = {"A": col} - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual(6, res.filter["filter"]["value"]) - filtr = {"col": "A", "op": "==", "val": "6"} - res = DruidDatasource.get_filters([filtr], ["A"], column_dict) - self.assertEqual(6, res.filter["filter"]["value"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_no_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - columns = [] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=columns, - client=client, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(1, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.timeseries.call_args_list[0][1] - self.assertNotIn("dimensions", called_args) - self.assertIn("post_aggregations", called_args) - # restore functions - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_with_adhoc_metric(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - all_metrics = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs)) - columns = [] - metrics = [ - { - "expressionType": "SIMPLE", - "column": {"type": "DOUBLE", "column_name": "col1"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - ] - - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=columns, - client=client, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(1, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.timeseries.call_args_list[0][1] - self.assertNotIn("dimensions", called_args) - self.assertIn("post_aggregations", called_args) - # restore functions - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_single_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = ["metric1"] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - columns = ["col1"] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder.last_query.query_dict = {"mock": 0} - # client.topn is called twice - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=columns, - timeseries_limit=100, - client=client, - order_desc=True, - filter=[], - ) - self.assertEqual(2, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args_pre = client.topn.call_args_list[0][1] - self.assertNotIn("dimensions", called_args_pre) - self.assertIn("dimension", called_args_pre) - called_args = client.topn.call_args_list[1][1] - self.assertIn("dimension", called_args) - self.assertEqual("col1", called_args["dimension"]) - # not order_desc - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=columns, - client=client, - order_desc=False, - filter=[], - row_limit=100, - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(1, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - self.assertIn("dimensions", client.groupby.call_args_list[0][1]) - self.assertEqual(["col1"], client.groupby.call_args_list[0][1]["dimensions"]) - # order_desc but timeseries and dimension spec - # calls topn with single dimension spec 'dimension' - spec = {"outputName": "hello", "dimension": "matcho"} - spec_json = json.dumps(spec) - col3 = DruidColumn(column_name="col3", dimension_spec_json=spec_json) - ds.columns.append(col3) - groupby = ["col3"] - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=groupby, - client=client, - order_desc=True, - timeseries_limit=5, - filter=[], - row_limit=100, - ) - self.assertEqual(2, len(client.topn.call_args_list)) - self.assertEqual(0, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - self.assertIn("dimension", client.topn.call_args_list[0][1]) - self.assertIn("dimension", client.topn.call_args_list[1][1]) - # uses dimension for pre query and full spec for final query - self.assertEqual("matcho", client.topn.call_args_list[0][1]["dimension"]) - self.assertEqual(spec, client.topn.call_args_list[1][1]["dimension"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_multiple_groupby(self): - client = Mock() - from_dttm = Mock() - to_dttm = Mock() - from_dttm.replace = Mock(return_value=from_dttm) - to_dttm.replace = Mock(return_value=to_dttm) - from_dttm.isoformat = Mock(return_value="from") - to_dttm.isoformat = Mock(return_value="to") - timezone = "timezone" - from_dttm.tzname = Mock(return_value=timezone) - ds = DruidDatasource(datasource_name="datasource") - metric1 = DruidMetric(metric_name="metric1") - metric2 = DruidMetric(metric_name="metric2") - ds.metrics = [metric1, metric2] - col1 = DruidColumn(column_name="col1") - col2 = DruidColumn(column_name="col2") - ds.columns = [col1, col2] - aggs = [] - post_aggs = ["some_agg"] - ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs)) - columns = ["col1", "col2"] - metrics = ["metric1"] - ds.get_having_filters = Mock(return_value=[]) - client.query_builder = Mock() - client.query_builder.last_query = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - # no groupby calls client.timeseries - ds.run_query( - metrics, - None, - from_dttm, - to_dttm, - groupby=columns, - client=client, - row_limit=100, - filter=[], - ) - self.assertEqual(0, len(client.topn.call_args_list)) - self.assertEqual(1, len(client.groupby.call_args_list)) - self.assertEqual(0, len(client.timeseries.call_args_list)) - # check that there is no dimensions entry - called_args = client.groupby.call_args_list[0][1] - self.assertIn("dimensions", called_args) - self.assertEqual(["col1", "col2"], called_args["dimensions"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_post_agg_returns_correct_agg_type(self): - get_post_agg = DruidDatasource.get_post_agg - # javascript PostAggregators - function = "function(field1, field2) { return field1 + field2; }" - conf = { - "type": "javascript", - "name": "postagg_name", - "fieldNames": ["field1", "field2"], - "function": function, - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["type"], "javascript") - self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"]) - self.assertEqual(postagg.post_aggregator["name"], "postagg_name") - self.assertEqual(postagg.post_aggregator["function"], function) - # Quantile - conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Quantile)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["probability"], "0.5") - # Quantiles - conf = { - "type": "quantiles", - "name": "postagg_name", - "probabilities": "0.4,0.5,0.6", - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Quantiles)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["probabilities"], "0.4,0.5,0.6") - # FieldAccess - conf = {"type": "fieldAccess", "name": "field_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Field)) - self.assertEqual(postagg.name, "field_name") - # constant - conf = {"type": "constant", "value": 1234, "name": "postagg_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Const)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["value"], 1234) - # hyperUniqueCardinality - conf = {"type": "hyperUniqueCardinality", "name": "unique_name"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality)) - self.assertEqual(postagg.name, "unique_name") - # arithmetic - conf = { - "type": "arithmetic", - "fn": "+", - "fields": ["field1", "field2"], - "name": "postagg_name", - } - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, postaggs.Postaggregator)) - self.assertEqual(postagg.name, "postagg_name") - self.assertEqual(postagg.post_aggregator["fn"], "+") - self.assertEqual(postagg.post_aggregator["fields"], ["field1", "field2"]) - # custom post aggregator - conf = {"type": "custom", "name": "custom_name", "stuff": "more_stuff"} - postagg = get_post_agg(conf) - self.assertTrue(isinstance(postagg, models.CustomPostAggregator)) - self.assertEqual(postagg.name, "custom_name") - self.assertEqual(postagg.post_aggregator["stuff"], "more_stuff") - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_find_postaggs_for_returns_postaggs_and_removes(self): - find_postaggs_for = DruidDatasource.find_postaggs_for - postagg_names = set(["pa2", "pa3", "pa4", "m1", "m2", "m3", "m4"]) - - metrics = {} - for i in range(1, 6): - emplace(metrics, "pa" + str(i), True) - emplace(metrics, "m" + str(i), False) - postagg_list = find_postaggs_for(postagg_names, metrics) - self.assertEqual(3, len(postagg_list)) - self.assertEqual(4, len(postagg_names)) - expected_metrics = ["m1", "m2", "m3", "m4"] - expected_postaggs = set(["pa2", "pa3", "pa4"]) - for postagg in postagg_list: - expected_postaggs.remove(postagg.metric_name) - for metric in expected_metrics: - postagg_names.remove(metric) - self.assertEqual(0, len(expected_postaggs)) - self.assertEqual(0, len(postagg_names)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_recursive_get_fields(self): - conf = { - "type": "quantile", - "fieldName": "f1", - "field": { - "type": "custom", - "fields": [ - {"type": "fieldAccess", "fieldName": "f2"}, - {"type": "fieldAccess", "fieldName": "f3"}, - { - "type": "quantiles", - "fieldName": "f4", - "field": {"type": "custom"}, - }, - { - "type": "custom", - "fields": [ - {"type": "fieldAccess", "fieldName": "f5"}, - { - "type": "fieldAccess", - "fieldName": "f2", - "fields": [ - {"type": "fieldAccess", "fieldName": "f3"}, - {"type": "fieldIgnoreMe", "fieldName": "f6"}, - ], - }, - ], - }, - ], - }, - } - fields = DruidDatasource.recursive_get_fields(conf) - expected = set(["f1", "f2", "f3", "f4", "f5"]) - self.assertEqual(5, len(fields)) - for field in fields: - expected.remove(field) - self.assertEqual(0, len(expected)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_metrics_and_post_aggs_tree(self): - metrics = ["A", "B", "m1", "m2"] - metrics_dict = {} - for i in range(ord("A"), ord("K") + 1): - emplace(metrics_dict, chr(i), True) - for i in range(1, 10): - emplace(metrics_dict, "m" + str(i), False) - - def depends_on(index, fields): - dependents = fields if isinstance(fields, list) else [fields] - metrics_dict[index].json_obj = {"fieldNames": dependents} - - depends_on("A", ["m1", "D", "C"]) - depends_on("B", ["B", "C", "E", "F", "m3"]) - depends_on("C", ["H", "I"]) - depends_on("D", ["m2", "m5", "G", "C"]) - depends_on("E", ["H", "I", "J"]) - depends_on("F", ["J", "m5"]) - depends_on("G", ["m4", "m7", "m6", "A"]) - depends_on("H", ["A", "m4", "I"]) - depends_on("I", ["H", "K"]) - depends_on("J", "K") - depends_on("K", ["m8", "m9"]) - aggs, postaggs = DruidDatasource.metrics_and_post_aggs(metrics, metrics_dict) - expected_metrics = set(aggs.keys()) - self.assertEqual(9, len(aggs)) - for i in range(1, 10): - expected_metrics.remove("m" + str(i)) - self.assertEqual(0, len(expected_metrics)) - self.assertEqual(11, len(postaggs)) - for i in range(ord("A"), ord("K") + 1): - del postaggs[chr(i)] - self.assertEqual(0, len(postaggs)) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_metrics_and_post_aggs(self): - """ - Test generation of metrics and post-aggregations from an initial list - of superset metrics (which may include the results of either). This - primarily tests that specifying a post-aggregator metric will also - require the raw aggregation of the associated druid metric column. - """ - metrics_dict = { - "unused_count": DruidMetric( - metric_name="unused_count", - verbose_name="COUNT(*)", - metric_type="count", - json=json.dumps({"type": "count", "name": "unused_count"}), - ), - "some_sum": DruidMetric( - metric_name="some_sum", - verbose_name="SUM(*)", - metric_type="sum", - json=json.dumps({"type": "sum", "name": "sum"}), - ), - "a_histogram": DruidMetric( - metric_name="a_histogram", - verbose_name="APPROXIMATE_HISTOGRAM(*)", - metric_type="approxHistogramFold", - json=json.dumps({"type": "approxHistogramFold", "name": "a_histogram"}), - ), - "aCustomMetric": DruidMetric( - metric_name="aCustomMetric", - verbose_name="MY_AWESOME_METRIC(*)", - metric_type="aCustomType", - json=json.dumps({"type": "customMetric", "name": "aCustomMetric"}), - ), - "quantile_p95": DruidMetric( - metric_name="quantile_p95", - verbose_name="P95(*)", - metric_type="postagg", - json=json.dumps( - { - "type": "quantile", - "probability": 0.95, - "name": "p95", - "fieldName": "a_histogram", - } - ), - ), - "aCustomPostAgg": DruidMetric( - metric_name="aCustomPostAgg", - verbose_name="CUSTOM_POST_AGG(*)", - metric_type="postagg", - json=json.dumps( - { - "type": "customPostAgg", - "name": "aCustomPostAgg", - "field": {"type": "fieldAccess", "fieldName": "aCustomMetric"}, - } - ), - ), - } - - adhoc_metric = { - "expressionType": "SIMPLE", - "column": {"type": "DOUBLE", "column_name": "value"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - - metrics = ["some_sum"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == {"some_sum"} - assert post_aggs == {} - - metrics = [adhoc_metric] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == set([adhoc_metric["label"]]) - assert post_aggs == {} - - metrics = ["some_sum", adhoc_metric] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - assert set(saved_metrics.keys()) == {"some_sum", adhoc_metric["label"]} - assert post_aggs == {} - - metrics = ["quantile_p95"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - result_postaggs = set(["quantile_p95"]) - assert set(saved_metrics.keys()) == {"a_histogram"} - assert set(post_aggs.keys()) == result_postaggs - - metrics = ["aCustomPostAgg"] - saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs( - metrics, metrics_dict - ) - - result_postaggs = set(["aCustomPostAgg"]) - assert set(saved_metrics.keys()) == {"aCustomMetric"} - assert set(post_aggs.keys()) == result_postaggs - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_druid_type_from_adhoc_metric(self): - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "DOUBLE", "column_name": "value"}, - "aggregate": "SUM", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "doubleSum" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "LONG", "column_name": "value"}, - "aggregate": "MAX", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "longMax" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "VARCHAR(255)", "column_name": "value"}, - "aggregate": "COUNT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "count" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "VARCHAR(255)", "column_name": "value"}, - "aggregate": "COUNT_DISTINCT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "cardinality" - - druid_type = DruidDatasource.druid_type_from_adhoc_metric( - { - "column": {"type": "hyperUnique", "column_name": "value"}, - "aggregate": "COUNT_DISTINCT", - "label": "My Adhoc Metric", - } - ) - assert druid_type == "hyperUnique" - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_run_query_order_by_metrics(self): - client = Mock() - client.query_builder.last_query.query_dict = {"mock": 0} - from_dttm = Mock() - to_dttm = Mock() - ds = DruidDatasource(datasource_name="datasource") - ds.get_having_filters = Mock(return_value=[]) - dim1 = DruidColumn(column_name="dim1") - dim2 = DruidColumn(column_name="dim2") - metrics_dict = { - "count1": DruidMetric( - metric_name="count1", - metric_type="count", - json=json.dumps({"type": "count", "name": "count1"}), - ), - "sum1": DruidMetric( - metric_name="sum1", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum1"}), - ), - "sum2": DruidMetric( - metric_name="sum2", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum2"}), - ), - "div1": DruidMetric( - metric_name="div1", - metric_type="postagg", - json=json.dumps( - { - "fn": "/", - "type": "arithmetic", - "name": "div1", - "fields": [ - {"fieldName": "sum1", "type": "fieldAccess"}, - {"fieldName": "sum2", "type": "fieldAccess"}, - ], - } - ), - ), - } - ds.columns = [dim1, dim2] - ds.metrics = list(metrics_dict.values()) - - columns = ["dim1"] - metrics = ["count1"] - granularity = "all" - # get the counts of the top 5 'dim1's, order by 'sum1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=columns, - timeseries_limit=5, - timeseries_limit_metric="sum1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.topn.call_args_list[0][1] - self.assertEqual("dim1", qry_obj["dimension"]) - self.assertEqual("sum1", qry_obj["metric"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) - self.assertEqual(set(), set(post_aggregations.keys())) - - # get the counts of the top 5 'dim1's, order by 'div1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=columns, - timeseries_limit=5, - timeseries_limit_metric="div1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.topn.call_args_list[1][1] - self.assertEqual("dim1", qry_obj["dimension"]) - self.assertEqual("div1", qry_obj["metric"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) - self.assertEqual({"div1"}, set(post_aggregations.keys())) - - columns = ["dim1", "dim2"] - # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=columns, - timeseries_limit=5, - timeseries_limit_metric="sum1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.groupby.call_args_list[0][1] - self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) - self.assertEqual("sum1", qry_obj["limit_spec"]["columns"][0]["dimension"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1"}, set(aggregations.keys())) - self.assertEqual(set(), set(post_aggregations.keys())) - - # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1' - ds.run_query( - metrics, - granularity, - from_dttm, - to_dttm, - groupby=columns, - timeseries_limit=5, - timeseries_limit_metric="div1", - client=client, - order_desc=True, - filter=[], - ) - qry_obj = client.groupby.call_args_list[1][1] - self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"])) - self.assertEqual("div1", qry_obj["limit_spec"]["columns"][0]["dimension"]) - aggregations = qry_obj["aggregations"] - post_aggregations = qry_obj["post_aggregations"] - self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys())) - self.assertEqual({"div1"}, set(post_aggregations.keys())) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_get_aggregations(self): - ds = DruidDatasource(datasource_name="datasource") - metrics_dict = { - "sum1": DruidMetric( - metric_name="sum1", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum1"}), - ), - "sum2": DruidMetric( - metric_name="sum2", - metric_type="doubleSum", - json=json.dumps({"type": "doubleSum", "name": "sum2"}), - ), - "div1": DruidMetric( - metric_name="div1", - metric_type="postagg", - json=json.dumps( - { - "fn": "/", - "type": "arithmetic", - "name": "div1", - "fields": [ - {"fieldName": "sum1", "type": "fieldAccess"}, - {"fieldName": "sum2", "type": "fieldAccess"}, - ], - } - ), - ), - } - metric_names = ["sum1", "sum2"] - aggs = ds.get_aggregations(metrics_dict, metric_names) - expected_agg = {name: metrics_dict[name].json_obj for name in metric_names} - self.assertEqual(expected_agg, aggs) - - metric_names = ["sum1", "col1"] - self.assertRaises( - SupersetException, ds.get_aggregations, metrics_dict, metric_names - ) - - metric_names = ["sum1", "div1"] - self.assertRaises( - SupersetException, ds.get_aggregations, metrics_dict, metric_names - ) diff --git a/tests/integration_tests/druid_tests.py b/tests/integration_tests/druid_tests.py deleted file mode 100644 index 66f5cc7244fc..000000000000 --- a/tests/integration_tests/druid_tests.py +++ /dev/null @@ -1,668 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# isort:skip_file -"""Unit tests for Superset""" -import json -import unittest -from datetime import datetime -from unittest.mock import Mock, patch - -from tests.integration_tests.test_app import app - -from superset import db, security_manager -from superset.connectors.druid.views import ( - Druid, - DruidClusterModelView, - DruidColumnInlineView, - DruidDatasourceModelView, - DruidMetricInlineView, -) - -from .base_tests import SupersetTestCase - - -try: - from superset.connectors.druid.models import ( - DruidCluster, - DruidColumn, - DruidDatasource, - DruidMetric, - ) -except ImportError: - pass - - -class PickableMock(Mock): - def __reduce__(self): - return (Mock, ()) - - -SEGMENT_METADATA = [ - { - "id": "some_id", - "intervals": ["2013-05-13T00:00:00.000Z/2013-05-14T00:00:00.000Z"], - "columns": { - "__time": { - "type": "LONG", - "hasMultipleValues": False, - "size": 407240380, - "cardinality": None, - "errorMessage": None, - }, - "dim1": { - "type": "STRING", - "hasMultipleValues": False, - "size": 100000, - "cardinality": 1944, - "errorMessage": None, - }, - "dim2": { - "type": "STRING", - "hasMultipleValues": True, - "size": 100000, - "cardinality": 1504, - "errorMessage": None, - }, - "metric1": { - "type": "FLOAT", - "hasMultipleValues": False, - "size": 100000, - "cardinality": None, - "errorMessage": None, - }, - }, - "aggregators": { - "metric1": {"type": "longSum", "name": "metric1", "fieldName": "metric1"} - }, - "size": 300000, - "numRows": 5000000, - } -] - -GB_RESULT_SET = [ - { - "version": "v1", - "timestamp": "2012-01-01T00:00:00.000Z", - "event": {"dim1": "Canada", "dim2": "boy", "count": 12345678}, - }, - { - "version": "v1", - "timestamp": "2012-01-01T00:00:00.000Z", - "event": {"dim1": "USA", "dim2": "girl", "count": 12345678 / 2}, - }, -] - -DruidCluster.get_druid_version = lambda _: "0.9.1" # type: ignore - - -class TestDruid(SupersetTestCase): - - """Testing interactions with Druid""" - - @classmethod - def setUpClass(cls): - cls.create_druid_test_objects() - - def get_test_cluster_obj(self): - return DruidCluster( - cluster_name="test_cluster", - broker_host="localhost", - broker_port=7980, - broker_endpoint="druid/v2", - metadata_last_refreshed=datetime.now(), - ) - - def get_cluster(self, PyDruid): - instance = PyDruid.return_value - instance.time_boundary.return_value = [{"result": {"maxTime": "2016-01-01"}}] - instance.segment_metadata.return_value = SEGMENT_METADATA - - cluster = ( - db.session.query(DruidCluster) - .filter_by(cluster_name="test_cluster") - .first() - ) - if cluster: - for datasource in ( - db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all() - ): - db.session.delete(datasource) - - db.session.delete(cluster) - db.session.commit() - - cluster = self.get_test_cluster_obj() - - db.session.add(cluster) - cluster.get_datasources = PickableMock(return_value=["test_datasource"]) - - return cluster - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_client(self, PyDruid): - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - cluster.refresh_datasources(merge_flag=True) - datasource_id = cluster.datasources[0].id - db.session.commit() - - nres = [ - list(v["event"].items()) + [("timestamp", v["timestamp"])] - for v in GB_RESULT_SET - ] - nres = [dict(v) for v in nres] - import pandas as pd - - df = pd.DataFrame(nres) - instance = PyDruid.return_value - instance.export_pandas.return_value = df - instance.query_dict = {} - instance.query_builder.last_query.query_dict = {} - - resp = self.get_resp("/superset/explore/druid/{}/".format(datasource_id)) - self.assertIn("test_datasource", resp) - form_data = { - "viz_type": "table", - "granularity": "one+day", - "druid_time_origin": "", - "since": "7 days ago", - "until": "now", - "row_limit": 5000, - "include_search": "false", - "metrics": ["count"], - "groupby": ["dim1"], - "force": "true", - } - # One groupby - url = "/superset/explore_json/druid/{}/".format(datasource_id) - resp = self.get_json_resp(url, {"form_data": json.dumps(form_data)}) - self.assertEqual("Canada", resp["data"]["records"][0]["dim1"]) - - form_data = { - "viz_type": "table", - "granularity": "one+day", - "druid_time_origin": "", - "since": "7 days ago", - "until": "now", - "row_limit": 5000, - "include_search": "false", - "metrics": ["count"], - "groupby": ["dim1", "dim2"], - "force": "true", - } - # two groupby - url = "/superset/explore_json/druid/{}/".format(datasource_id) - resp = self.get_json_resp(url, {"form_data": json.dumps(form_data)}) - self.assertEqual("Canada", resp["data"]["records"][0]["dim1"]) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_druid_sync_from_config(self): - CLUSTER_NAME = "new_druid" - self.login() - cluster = self.get_or_create( - DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session - ) - - db.session.merge(cluster) - db.session.commit() - - ds = ( - db.session.query(DruidDatasource) - .filter_by(datasource_name="test_click") - .first() - ) - if ds: - db.session.delete(ds) - db.session.commit() - - cfg = { - "user": "admin", - "cluster": CLUSTER_NAME, - "config": { - "name": "test_click", - "dimensions": ["affiliate_id", "campaign", "first_seen"], - "metrics_spec": [ - {"type": "count", "name": "count"}, - {"type": "sum", "name": "sum"}, - ], - "batch_ingestion": { - "sql": "SELECT * FROM clicks WHERE d='{{ ds }}'", - "ts_column": "d", - "sources": [{"table": "clicks", "partition": "d='{{ ds }}'"}], - }, - }, - } - - def check(): - resp = self.client.post("/superset/sync_druid/", data=json.dumps(cfg)) - druid_ds = ( - db.session.query(DruidDatasource) - .filter_by(datasource_name="test_click") - .one() - ) - col_names = set([c.column_name for c in druid_ds.columns]) - assert {"affiliate_id", "campaign", "first_seen"} == col_names - metric_names = {m.metric_name for m in druid_ds.metrics} - assert {"count", "sum"} == metric_names - assert resp.status_code == 201 - - check() - # checking twice to make sure a second sync yields the same results - check() - - # datasource exists, add new metrics and dimensions - cfg = { - "user": "admin", - "cluster": CLUSTER_NAME, - "config": { - "name": "test_click", - "dimensions": ["affiliate_id", "second_seen"], - "metrics_spec": [ - {"type": "bla", "name": "sum"}, - {"type": "unique", "name": "unique"}, - ], - }, - } - resp = self.client.post("/superset/sync_druid/", data=json.dumps(cfg)) - druid_ds = ( - db.session.query(DruidDatasource) - .filter_by(datasource_name="test_click") - .one() - ) - # columns and metrics are not deleted if config is changed as - # user could define their own dimensions / metrics and want to keep them - assert set([c.column_name for c in druid_ds.columns]) == set( - ["affiliate_id", "campaign", "first_seen", "second_seen"] - ) - assert set([m.metric_name for m in druid_ds.metrics]) == set( - ["count", "sum", "unique"] - ) - # metric type will not be overridden, sum stays instead of bla - assert set([m.metric_type for m in druid_ds.metrics]) == set( - ["longSum", "sum", "unique"] - ) - assert resp.status_code == 201 - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false") - def test_filter_druid_datasource(self): - CLUSTER_NAME = "new_druid" - cluster = self.get_or_create( - DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session - ) - db.session.merge(cluster) - - gamma_ds = self.get_or_create( - DruidDatasource, - {"datasource_name": "datasource_for_gamma", "cluster": cluster}, - db.session, - ) - gamma_ds.cluster = cluster - db.session.merge(gamma_ds) - - no_gamma_ds = self.get_or_create( - DruidDatasource, - {"datasource_name": "datasource_not_for_gamma", "cluster": cluster}, - db.session, - ) - no_gamma_ds.cluster = cluster - db.session.merge(no_gamma_ds) - db.session.commit() - - security_manager.add_permission_view_menu("datasource_access", gamma_ds.perm) - security_manager.add_permission_view_menu("datasource_access", no_gamma_ds.perm) - - perm = security_manager.find_permission_view_menu( - "datasource_access", gamma_ds.get_perm() - ) - security_manager.add_permission_role(security_manager.find_role("Gamma"), perm) - security_manager.get_session.commit() - - self.login(username="gamma") - url = "/druiddatasourcemodelview/list/" - resp = self.get_resp(url) - self.assertIn("datasource_for_gamma", resp) - self.assertNotIn("datasource_not_for_gamma", resp) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_sync_druid_perm(self, PyDruid): - self.login(username="admin") - instance = PyDruid.return_value - instance.time_boundary.return_value = [{"result": {"maxTime": "2016-01-01"}}] - instance.segment_metadata.return_value = SEGMENT_METADATA - - cluster = ( - db.session.query(DruidCluster) - .filter_by(cluster_name="test_cluster") - .first() - ) - if cluster: - for datasource in ( - db.session.query(DruidDatasource).filter_by(cluster_id=cluster.id).all() - ): - db.session.delete(datasource) - - db.session.delete(cluster) - db.session.commit() - - cluster = DruidCluster( - cluster_name="test_cluster", - broker_host="localhost", - broker_port=7980, - metadata_last_refreshed=datetime.now(), - ) - - db.session.add(cluster) - cluster.get_datasources = PickableMock(return_value=["test_datasource"]) - - cluster.refresh_datasources() - cluster.datasources[0].merge_flag = True - metadata = cluster.datasources[0].latest_metadata() - self.assertEqual(len(metadata), 4) - db.session.commit() - - view_menu_name = cluster.datasources[0].get_perm() - view_menu = security_manager.find_view_menu(view_menu_name) - permission = security_manager.find_permission("datasource_access") - - pv = ( - security_manager.get_session.query(security_manager.permissionview_model) - .filter_by(permission=permission, view_menu=view_menu) - .first() - ) - assert pv is not None - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_refresh_metadata(self, PyDruid): - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - datasource = cluster.datasources[0] - - cols = db.session.query(DruidColumn).filter( - DruidColumn.datasource_id == datasource.id - ) - - for col in cols: - self.assertIn(col.column_name, SEGMENT_METADATA[0]["columns"].keys()) - - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like("%__metric1")) - ) - - for metric in metrics: - agg, _ = metric.metric_name.split("__") - - self.assertEqual( - json.loads(metric.json)["type"], "double{}".format(agg.capitalize()) - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_refresh_metadata_augment_type(self, PyDruid): - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - - metadata = SEGMENT_METADATA[:] - metadata[0]["columns"]["metric1"]["type"] = "LONG" - instance = PyDruid.return_value - instance.segment_metadata.return_value = metadata - cluster.refresh_datasources() - datasource = cluster.datasources[0] - - column = ( - db.session.query(DruidColumn) - .filter(DruidColumn.datasource_id == datasource.id) - .filter(DruidColumn.column_name == "metric1") - ).one() - - self.assertEqual(column.type, "LONG") - - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like("%__metric1")) - ) - - for metric in metrics: - agg, _ = metric.metric_name.split("__") - - self.assertEqual(metric.json_obj["type"], "long{}".format(agg.capitalize())) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_refresh_metadata_augment_verbose_name(self, PyDruid): - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - datasource = cluster.datasources[0] - - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like("%__metric1")) - ) - - for metric in metrics: - metric.verbose_name = metric.metric_name - - db.session.commit() - - # The verbose name should not change during a refresh. - cluster.refresh_datasources() - datasource = cluster.datasources[0] - - metrics = ( - db.session.query(DruidMetric) - .filter(DruidMetric.datasource_id == datasource.id) - .filter(DruidMetric.metric_name.like("%__metric1")) - ) - - for metric in metrics: - self.assertEqual(metric.verbose_name, metric.metric_name) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - def test_urls(self): - cluster = self.get_test_cluster_obj() - self.assertEqual( - cluster.get_base_url("localhost", "9999"), "http://localhost:9999" - ) - self.assertEqual( - cluster.get_base_url("http://localhost", "9999"), "http://localhost:9999" - ) - self.assertEqual( - cluster.get_base_url("https://localhost", "9999"), "https://localhost:9999" - ) - - self.assertEqual( - cluster.get_base_broker_url(), "http://localhost:7980/druid/v2" - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_druid_time_granularities(self, PyDruid): - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - cluster.refresh_datasources(merge_flag=True) - datasource_id = cluster.datasources[0].id - db.session.commit() - - nres = [ - list(v["event"].items()) + [("timestamp", v["timestamp"])] - for v in GB_RESULT_SET - ] - nres = [dict(v) for v in nres] - import pandas as pd - - df = pd.DataFrame(nres) - instance = PyDruid.return_value - instance.export_pandas.return_value = df - instance.query_dict = {} - instance.query_builder.last_query.query_dict = {} - - form_data = { - "viz_type": "table", - "since": "7 days ago", - "until": "now", - "metrics": ["count"], - "groupby": [], - "include_time": "true", - } - - granularity_map = { - "5 seconds": "PT5S", - "30 seconds": "PT30S", - "1 minute": "PT1M", - "5 minutes": "PT5M", - "1 hour": "PT1H", - "6 hour": "PT6H", - "one day": "P1D", - "1 day": "P1D", - "7 days": "P7D", - "week": "P1W", - "week_starting_sunday": "P1W", - "week_ending_saturday": "P1W", - "month": "P1M", - "quarter": "P3M", - "year": "P1Y", - } - url = "/superset/explore_json/druid/{}/".format(datasource_id) - - for granularity_mapping in granularity_map: - form_data["granularity"] = granularity_mapping - self.get_json_resp(url, {"form_data": json.dumps(form_data)}) - self.assertEqual( - granularity_map[granularity_mapping], - instance.timeseries.call_args[1]["granularity"]["period"], - ) - - @unittest.skipUnless( - SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" - ) - @patch("superset.connectors.druid.models.PyDruid") - def test_external_metadata(self, PyDruid): - self.login(username="admin") - self.login(username="admin") - cluster = self.get_cluster(PyDruid) - cluster.refresh_datasources() - datasource = cluster.datasources[0] - url = "/datasource/external_metadata/druid/{}/".format(datasource.id) - resp = self.get_json_resp(url) - col_names = {o.get("name") for o in resp} - self.assertEqual(col_names, {"__time", "dim1", "dim2", "metric1"}) - - -class TestDruidViewEnabling(SupersetTestCase): - def test_druid_disabled(self): - with patch.object(Druid, "is_enabled", return_value=False): - self.login("admin") - uri = "/druid/refresh_datasources/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_druid_enabled(self): - with patch.object(Druid, "is_enabled", return_value=True): - self.login("admin") - uri = "/druid/refresh_datasources/" - rv = self.client.get(uri) - self.assertLess(rv.status_code, 400) - - def test_druid_cluster_disabled(self): - with patch.object(DruidClusterModelView, "is_enabled", return_value=False): - self.login("admin") - uri = "/druidclustermodelview/list/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_druid_cluster_enabled(self): - with patch.object(DruidClusterModelView, "is_enabled", return_value=True): - self.login("admin") - uri = "/druidclustermodelview/list/" - rv = self.client.get(uri) - self.assertLess(rv.status_code, 400) - - def test_druid_column_disabled(self): - with patch.object(DruidColumnInlineView, "is_enabled", return_value=False): - self.login("admin") - uri = "/druidcolumninlineview/list/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_druid_column_enabled(self): - with patch.object(DruidColumnInlineView, "is_enabled", return_value=True): - self.login("admin") - uri = "/druidcolumninlineview/list/" - rv = self.client.get(uri) - self.assertLess(rv.status_code, 400) - - def test_druid_datasource_disabled(self): - with patch.object(DruidDatasourceModelView, "is_enabled", return_value=False): - self.login("admin") - uri = "/druiddatasourcemodelview/list/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_druid_datasource_enabled(self): - with patch.object(DruidDatasourceModelView, "is_enabled", return_value=True): - self.login("admin") - uri = "/druiddatasourcemodelview/list/" - rv = self.client.get(uri) - self.assertLess(rv.status_code, 400) - - def test_druid_metric_disabled(self): - with patch.object(DruidMetricInlineView, "is_enabled", return_value=False): - self.login("admin") - uri = "/druidmetricinlineview/list/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_druid_metric_enabled(self): - with patch.object(DruidMetricInlineView, "is_enabled", return_value=True): - self.login("admin") - uri = "/druidmetricinlineview/list/" - rv = self.client.get(uri) - self.assertLess(rv.status_code, 400) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 67d2a89d866f..6d7d581ec6d4 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -34,12 +34,7 @@ from tests.integration_tests.test_app import app from superset.dashboards.commands.importers.v0 import decode_dashboards from superset import db, security_manager -from superset.connectors.druid.models import ( - DruidColumn, - DruidDatasource, - DruidMetric, - DruidCluster, -) + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.dashboards.commands.importers.v0 import import_chart, import_dashboard from superset.datasets.commands.importers.v0 import import_dataset @@ -72,15 +67,11 @@ def delete_imports(cls): for table in session.query(SqlaTable): if "remote_id" in table.params_dict: session.delete(table) - for datasource in session.query(DruidDatasource): - if "remote_id" in datasource.params_dict: - session.delete(datasource) session.commit() @classmethod def setUpClass(cls): cls.delete_imports() - cls.create_druid_test_objects() @classmethod def tearDownClass(cls): @@ -141,25 +132,6 @@ def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]): table.metrics.append(SqlMetric(metric_name=metric_name, expression="")) return table - def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]): - cluster_name = "druid_test" - cluster = self.get_or_create( - DruidCluster, {"cluster_name": cluster_name}, db.session - ) - - params = {"remote_id": id, "database_name": cluster_name} - datasource = DruidDatasource( - id=id, - datasource_name=name, - cluster_id=cluster.id, - params=json.dumps(params), - ) - for col_name in cols_names: - datasource.columns.append(DruidColumn(column_name=col_name)) - for metric_name in metric_names: - datasource.metrics.append(DruidMetric(metric_name=metric_name, json="{}")) - return datasource - def get_slice(self, slc_id): return db.session.query(Slice).filter_by(id=slc_id).first() @@ -169,9 +141,6 @@ def get_slice_by_name(self, name): def get_dash(self, dash_id): return db.session.query(Dashboard).filter_by(id=dash_id).first() - def get_datasource(self, datasource_id): - return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() - def assert_dash_equals( self, expected_dash, actual_dash, check_position=True, check_slugs=True ): @@ -704,78 +673,6 @@ def test_import_table_override_identical(self): self.assertEqual(imported_id, imported_id_copy) self.assert_table_equals(copy_table, self.get_table_by_id(imported_id)) - def test_import_druid_no_metadata(self): - datasource = self.create_druid_datasource("pure_druid", id=10001) - imported_id = import_dataset(datasource, import_time=1989) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_1_col_1_met(self): - datasource = self.create_druid_datasource( - "druid_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"] - ) - imported_id = import_dataset(datasource, import_time=1990) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - self.assertEqual( - {"remote_id": 10002, "import_time": 1990, "database_name": "druid_test"}, - json.loads(imported.params), - ) - - def test_import_druid_2_col_2_met(self): - datasource = self.create_druid_datasource( - "druid_2_col_2_met", - id=10003, - cols_names=["c1", "c2"], - metric_names=["m1", "m2"], - ) - imported_id = import_dataset(datasource, import_time=1991) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_override(self): - datasource = self.create_druid_datasource( - "druid_override", id=10004, cols_names=["col1"], metric_names=["m1"] - ) - imported_id = import_dataset(datasource, import_time=1991) - table_over = self.create_druid_datasource( - "druid_override", - id=10004, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_over_id = import_dataset(table_over, import_time=1992) - - imported_over = self.get_datasource(imported_over_id) - self.assertEqual(imported_id, imported_over.id) - expected_datasource = self.create_druid_datasource( - "druid_override", - id=10004, - metric_names=["new_metric1", "m1"], - cols_names=["col1", "new_col1", "col2", "col3"], - ) - self.assert_datasource_equals(expected_datasource, imported_over) - - def test_import_druid_override_identical(self): - datasource = self.create_druid_datasource( - "copy_cat", - id=10005, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_id = import_dataset(datasource, import_time=1993) - - copy_datasource = self.create_druid_datasource( - "copy_cat", - id=10005, - cols_names=["new_col1", "col2", "col3"], - metric_names=["new_metric1"], - ) - imported_id_copy = import_dataset(copy_datasource, import_time=1994) - - self.assertEqual(imported_id, imported_id_copy) - self.assert_datasource_equals(copy_datasource, self.get_datasource(imported_id)) - if __name__ == "__main__": unittest.main() diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 3add863de839..1b6d1318db80 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -32,7 +32,6 @@ from superset.models.dashboard import Dashboard from superset import app, appbuilder, db, security_manager, viz, ConnectorRegistry -from superset.connectors.druid.models import DruidCluster, DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException @@ -273,93 +272,6 @@ def test_set_perm_sqla_table(self): session.delete(stored_table) session.commit() - @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") - def test_set_perm_druid_datasource(self): - self.create_druid_test_objects() - session = db.session - druid_cluster = ( - session.query(DruidCluster).filter_by(cluster_name="druid_test").one() - ) - datasource = DruidDatasource( - datasource_name="tmp_datasource", - cluster=druid_cluster, - cluster_id=druid_cluster.id, - ) - session.add(datasource) - session.commit() - - # store without a schema - stored_datasource = ( - session.query(DruidDatasource) - .filter_by(datasource_name="tmp_datasource") - .one() - ) - self.assertEqual( - stored_datasource.perm, - f"[druid_test].[tmp_datasource](id:{stored_datasource.id})", - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_datasource.perm - ) - ) - self.assertIsNone(stored_datasource.schema_perm) - - # store with a schema - stored_datasource.datasource_name = "tmp_schema.tmp_datasource" - session.commit() - self.assertEqual( - stored_datasource.perm, - f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})", - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_datasource.perm - ) - ) - self.assertIsNotNone(stored_datasource.schema_perm, "[druid_test].[tmp_schema]") - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "schema_access", stored_datasource.schema_perm - ) - ) - - session.delete(stored_datasource) - session.commit() - - def test_set_perm_druid_cluster(self): - session = db.session - cluster = DruidCluster(cluster_name="tmp_druid_cluster") - session.add(cluster) - - stored_cluster = ( - session.query(DruidCluster) - .filter_by(cluster_name="tmp_druid_cluster") - .one() - ) - self.assertEqual( - stored_cluster.perm, f"[tmp_druid_cluster].(id:{stored_cluster.id})" - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "database_access", stored_cluster.perm - ) - ) - - stored_cluster.cluster_name = "tmp_druid_cluster2" - session.commit() - self.assertEqual( - stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})" - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "database_access", stored_cluster.perm - ) - ) - - session.delete(stored_cluster) - session.commit() - def test_set_perm_database(self): session = db.session database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") @@ -390,28 +302,6 @@ def test_set_perm_database(self): session.delete(stored_db) session.commit() - def test_hybrid_perm_druid_cluster(self): - cluster = DruidCluster(cluster_name="tmp_druid_cluster3") - db.session.add(cluster) - - id_ = ( - db.session.query(DruidCluster.id) - .filter_by(cluster_name="tmp_druid_cluster3") - .scalar() - ) - - record = ( - db.session.query(DruidCluster) - .filter_by(perm=f"[tmp_druid_cluster3].(id:{id_})") - .one() - ) - - self.assertEqual(record.get_perm(), record.perm) - self.assertEqual(record.id, id_) - self.assertEqual(record.cluster_name, "tmp_druid_cluster3") - db.session.delete(cluster) - db.session.commit() - def test_hybrid_perm_database(self): database = Database(database_name="tmp_database3", sqlalchemy_uri="sqlite://") @@ -706,7 +596,6 @@ def assert_can_admin(self, perm_set): self.assertIn(("all_database_access", "all_database_access"), perm_set) self.assertIn(("can_override_role_permissions", "Superset"), perm_set) - self.assertIn(("can_sync_druid_source", "Superset"), perm_set) self.assertIn(("can_override_role_permissions", "Superset"), perm_set) self.assertIn(("can_approve", "Superset"), perm_set)