diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 128ce511be68..63ee7d7f9dcd 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -72,7 +72,6 @@ from superset.utils.network import is_hostname_valid, is_port_open if TYPE_CHECKING: - # prevent circular imports from superset.connectors.sqla.models import TableColumn from superset.models.core import Database from superset.models.sql_lab import Query diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index 930aeee52839..3a95771bcc31 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -14,29 +14,43 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging +import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING +from flask import current_app +from flask_babel import gettext as __ +from marshmallow import fields, Schema +from marshmallow.validate import Range from sqlalchemy import types +from sqlalchemy.engine.url import URL from urllib3.exceptions import NewConnectionError -from superset.db_engine_specs.base import BaseEngineSpec +from superset.databases.utils import make_url_safe +from superset.db_engine_specs.base import ( + BaseEngineSpec, + BasicParametersMixin, + BasicParametersType, + BasicPropertiesType, +) from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import cache_manager +from superset.utils.core import GenericDataType +from superset.utils.hashing import md5_sha_from_str +from superset.utils.network import is_hostname_valid, is_port_open if TYPE_CHECKING: - # prevent circular imports from superset.models.core import Database logger = logging.getLogger(__name__) -class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method - """Dialect for ClickHouse analytical DB.""" - - engine = "clickhouse" - engine_name = "ClickHouse" +class ClickHouseBaseEngineSpec(BaseEngineSpec): + """Shared engine spec for ClickHouse.""" time_secondary_columns = True time_groupby_inline = True @@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "P1Y": "toStartOfYear(toDateTime({col}))", } - _show_functions_column = "name" + column_type_mappings = ( + ( + re.compile(r".*Enum.*", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r".*Array.*", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r".*UUID.*", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r".*Bool.*", re.IGNORECASE), + types.Boolean(), + GenericDataType.BOOLEAN, + ), + ( + re.compile(r".*String.*", re.IGNORECASE), + types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r".*Int\d+.*", re.IGNORECASE), + types.INTEGER(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r".*Decimal.*", re.IGNORECASE), + types.DECIMAL(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r".*DateTime.*", re.IGNORECASE), + types.DateTime(), + GenericDataType.TEMPORAL, + ), + ( + re.compile(r".*Date.*", re.IGNORECASE), + types.Date(), + GenericDataType.TEMPORAL, + ), + ) + + @classmethod + def epoch_to_dttm(cls) -> str: + return "{col}" + + @classmethod + def convert_dttm( + cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + ) -> Optional[str]: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): + return f"toDate('{dttm.date().isoformat()}')" + if isinstance(sqla_type, types.DateTime): + return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')""" + return None + +class ClickHouseEngineSpec(ClickHouseBaseEngineSpec): + """Engine spec for clickhouse_sqlalchemy connector""" + + engine = "clickhouse" + engine_name = "ClickHouse" + + _show_functions_column = "name" supports_file_upload = False @classmethod @@ -73,21 +157,9 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: return exception return new_exception(str(exception)) - @classmethod - def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: - sqla_type = cls.get_sqla_column_type(target_type) - - if isinstance(sqla_type, types.Date): - return f"toDate('{dttm.date().isoformat()}')" - if isinstance(sqla_type, types.DateTime): - return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')""" - return None - @classmethod @cache_manager.cache.memoize() - def get_function_names(cls, database: "Database") -> List[str]: + def get_function_names(cls, database: Database) -> List[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -123,3 +195,201 @@ def get_function_names(cls, database: "Database") -> List[str]: # otherwise, return no function names to prevent errors return [] + + +class ClickHouseParametersSchema(Schema): + username = fields.String(allow_none=True, description=__("Username")) + password = fields.String(allow_none=True, description=__("Password")) + host = fields.String(required=True, description=__("Hostname or IP address")) + port = fields.Integer( + allow_none=True, + description=__("Database port"), + validate=Range(min=0, max=65535), + ) + database = fields.String(allow_none=True, description=__("Database name")) + encryption = fields.Boolean( + default=True, description=__("Use an encrypted connection to the database") + ) + query = fields.Dict( + keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters") + ) + + +try: + from clickhouse_connect.common import set_setting + from clickhouse_connect.datatypes.format import set_default_formats + + # override default formats for compatibility + set_default_formats( + "FixedString", + "string", + "IPv*", + "string", + "signed", + "UUID", + "string", + "*Int256", + "string", + "*Int128", + "string", + ) + set_setting( + "product_name", + f"superset/{current_app.config.get('VERSION_STRING', 'dev')}", + ) +except ImportError: # ClickHouse Connect not installed, do nothing + pass + + +class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): + """Engine spec for clickhouse-connect connector""" + + engine = "clickhousedb" + engine_name = "ClickHouse Connect" + + default_driver = "connect" + _function_names: List[str] = [] + + sqlalchemy_uri_placeholder = ( + "clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]" + ) + parameters_schema = ClickHouseParametersSchema() + encryption_parameters = {"secure": "true"} + + @classmethod + def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + return {} + + @classmethod + def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: + new_exception = cls.get_dbapi_exception_mapping().get(type(exception)) + if new_exception == SupersetDBAPIDatabaseError: + return SupersetDBAPIDatabaseError("Connection failed") + if not new_exception: + return exception + return new_exception(str(exception)) + + @classmethod + def get_function_names(cls, database: Database) -> List[str]: + # pylint: disable=import-outside-toplevel,import-error + from clickhouse_connect.driver.exceptions import ClickHouseError + + if cls._function_names: + return cls._function_names + try: + names = database.get_df( + "SELECT name FROM system.functions UNION ALL " + + "SELECT name FROM system.table_functions LIMIT 10000" + )["name"].tolist() + cls._function_names = names + return names + except ClickHouseError: + logger.exception("Error retrieving system.functions") + return [] + + @classmethod + def get_datatype(cls, type_code: str) -> str: + # keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL + return type_code + + @classmethod + def build_sqlalchemy_uri( + cls, + parameters: BasicParametersType, + encrypted_extra: Optional[Dict[str, str]] = None, + ) -> str: + url_params = parameters.copy() + if url_params.get("encryption"): + query = parameters.get("query", {}).copy() + query.update(cls.encryption_parameters) + url_params["query"] = query + if not url_params.get("database"): + url_params["database"] = "__default__" + url_params.pop("encryption", None) + return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params)) + + @classmethod + def get_parameters_from_uri( + cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None + ) -> BasicParametersType: + url = make_url_safe(uri) + query = url.query + if "secure" in query: + encryption = url.query.get("secure") == "true" + query.pop("secure") + else: + encryption = False + return BasicParametersType( + username=url.username, + password=url.password, + host=url.host, + port=url.port, + database="" if url.database == "__default__" else cast(str, url.database), + query=dict(query), + encryption=encryption, + ) + + @classmethod + def validate_parameters( + cls, properties: BasicPropertiesType + ) -> List[SupersetError]: + # pylint: disable=import-outside-toplevel,import-error + from clickhouse_connect.driver import default_port + + parameters = properties.get("parameters", {}) + host = parameters.get("host", None) + if not host: + return [ + SupersetError( + "Hostname is required", + SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + ErrorLevel.WARNING, + {"missing": ["host"]}, + ) + ] + if not is_hostname_valid(host): + return [ + SupersetError( + "The hostname provided can't be resolved.", + SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR, + ErrorLevel.ERROR, + {"invalid": ["host"]}, + ) + ] + port = parameters.get("port") + if port is None: + port = default_port("http", parameters.get("encryption", False)) + try: + port = int(port) + except (ValueError, TypeError): + port = -1 + if port <= 0 or port >= 65535: + return [ + SupersetError( + "Port must be a valid integer between 0 and 65535 (inclusive).", + SupersetErrorType.CONNECTION_INVALID_PORT_ERROR, + ErrorLevel.ERROR, + {"invalid": ["port"]}, + ) + ] + if not is_port_open(host, port): + return [ + SupersetError( + "The port is closed.", + SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR, + ErrorLevel.ERROR, + {"invalid": ["port"]}, + ) + ] + return [] + + @staticmethod + def _mutate_label(label: str) -> str: + """ + Suffix with the first six characters from the md5 of the label to avoid + collisions with original column names + + :param label: Expected expression label + :return: Conditionally mutated label + """ + return f"{label}_{md5_sha_from_str(label)[:6]}" diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 9a52b04616af..0c437bc00998 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -16,12 +16,26 @@ # under the License. from datetime import datetime -from typing import Optional +from typing import Any, Dict, Optional, Type from unittest.mock import Mock import pytest +from sqlalchemy.types import ( + Boolean, + Date, + DateTime, + DECIMAL, + Float, + Integer, + String, + TypeEngine, +) -from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) from tests.unit_tests.fixtures.common import dttm @@ -53,3 +67,147 @@ def test_execute_connection_error() -> None: ) with pytest.raises(SupersetDBAPIDatabaseError) as ex: ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1") + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "toDate('2019-01-02')"), + ("DateTime", "toDateTime('2019-01-02 03:04:05')"), + ("UnknownType", None), + ], +) +def test_connect_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + ("String", String, None, GenericDataType.STRING, False), + ("LowCardinality(String)", String, None, GenericDataType.STRING, False), + ("Nullable(String)", String, None, GenericDataType.STRING, False), + ( + "LowCardinality(Nullable(String))", + String, + None, + GenericDataType.STRING, + False, + ), + ("Array(UInt8)", String, None, GenericDataType.STRING, False), + ("Enum('hello', 'world')", String, None, GenericDataType.STRING, False), + ("Enum('UInt32', 'Bool')", String, None, GenericDataType.STRING, False), + ( + "LowCardinality(Enum('hello', 'world'))", + String, + None, + GenericDataType.STRING, + False, + ), + ( + "Nullable(Enum('hello', 'world'))", + String, + None, + GenericDataType.STRING, + False, + ), + ( + "LowCardinality(Nullable(Enum('hello', 'world')))", + String, + None, + GenericDataType.STRING, + False, + ), + ("FixedString(16)", String, None, GenericDataType.STRING, False), + ("Nullable(FixedString(16))", String, None, GenericDataType.STRING, False), + ( + "LowCardinality(Nullable(FixedString(16)))", + String, + None, + GenericDataType.STRING, + False, + ), + ("UUID", String, None, GenericDataType.STRING, False), + ("Int8", Integer, None, GenericDataType.NUMERIC, False), + ("Int16", Integer, None, GenericDataType.NUMERIC, False), + ("Int32", Integer, None, GenericDataType.NUMERIC, False), + ("Int64", Integer, None, GenericDataType.NUMERIC, False), + ("Int128", Integer, None, GenericDataType.NUMERIC, False), + ("Int256", Integer, None, GenericDataType.NUMERIC, False), + ("Nullable(Int256)", Integer, None, GenericDataType.NUMERIC, False), + ( + "LowCardinality(Nullable(Int256))", + Integer, + None, + GenericDataType.NUMERIC, + False, + ), + ("UInt8", Integer, None, GenericDataType.NUMERIC, False), + ("UInt16", Integer, None, GenericDataType.NUMERIC, False), + ("UInt32", Integer, None, GenericDataType.NUMERIC, False), + ("UInt64", Integer, None, GenericDataType.NUMERIC, False), + ("UInt128", Integer, None, GenericDataType.NUMERIC, False), + ("UInt256", Integer, None, GenericDataType.NUMERIC, False), + ("Nullable(UInt256)", Integer, None, GenericDataType.NUMERIC, False), + ( + "LowCardinality(Nullable(UInt256))", + Integer, + None, + GenericDataType.NUMERIC, + False, + ), + ("Float32", Float, None, GenericDataType.NUMERIC, False), + ("Float64", Float, None, GenericDataType.NUMERIC, False), + ("Decimal(1, 2)", DECIMAL, None, GenericDataType.NUMERIC, False), + ("Decimal32(2)", DECIMAL, None, GenericDataType.NUMERIC, False), + ("Decimal64(2)", DECIMAL, None, GenericDataType.NUMERIC, False), + ("Decimal128(2)", DECIMAL, None, GenericDataType.NUMERIC, False), + ("Decimal256(2)", DECIMAL, None, GenericDataType.NUMERIC, False), + ("Bool", Boolean, None, GenericDataType.BOOLEAN, False), + ("Nullable(Bool)", Boolean, None, GenericDataType.BOOLEAN, False), + ("Date", Date, None, GenericDataType.TEMPORAL, True), + ("Nullable(Date)", Date, None, GenericDataType.TEMPORAL, True), + ("LowCardinality(Nullable(Date))", Date, None, GenericDataType.TEMPORAL, True), + ("Date32", Date, None, GenericDataType.TEMPORAL, True), + ("Datetime", DateTime, None, GenericDataType.TEMPORAL, True), + ("Nullable(Datetime)", DateTime, None, GenericDataType.TEMPORAL, True), + ( + "LowCardinality(Nullable(Datetime))", + DateTime, + None, + GenericDataType.TEMPORAL, + True, + ), + ("Datetime('UTC')", DateTime, None, GenericDataType.TEMPORAL, True), + ("Datetime64(3)", DateTime, None, GenericDataType.TEMPORAL, True), + ("Datetime64(3, 'UTC')", DateTime, None, GenericDataType.TEMPORAL, True), + ], +) +def test_connect_get_column_spec( + native_type: str, + sqla_type: Type[TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec + + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +@pytest.mark.parametrize( + "column_name,expected_result", + [ + ("time", "time_07cc69"), + ("count", "count_e2942a"), + ], +) +def test_connect_make_label_compatible(column_name: str, expected_result: str) -> None: + from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec as spec + + label = spec.make_label_compatible(column_name) + assert label == expected_result diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 63a315c14c87..554ad97055f6 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -20,7 +20,7 @@ from typing import Any, Dict, Optional, Type import pytest -from sqlalchemy import column, table, types +from sqlalchemy import column, table from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR from sqlalchemy.sql import select @@ -50,7 +50,7 @@ ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], + sqla_type: Type[TypeEngine], attrs: Optional[Dict[str, Any]], generic_type: GenericDataType, is_dttm: bool,