Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(clickhouse): add clickhouse connect driver #23185

Merged
merged 4 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from superset.utils.network import is_hostname_valid, is_port_open

if TYPE_CHECKING:
# prevent circular imports
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
from superset.models.sql_lab import Query
Expand Down
301 changes: 279 additions & 22 deletions superset/db_engine_specs/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,43 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING

from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.url import URL
from urllib3.exceptions import NewConnectionError

from superset.db_engine_specs.base import BaseEngineSpec
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
BasicParametersType,
BasicPropertiesType,
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import cache_manager
from superset.utils.core import GenericDataType
from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database

logger = logging.getLogger(__name__)


class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"""Dialect for ClickHouse analytical DB."""

engine = "clickhouse"
engine_name = "ClickHouse"
class ClickHouseBaseEngineSpec(BaseEngineSpec):
"""Shared engine spec for ClickHouse."""

time_secondary_columns = True
time_groupby_inline = True
Expand All @@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "toStartOfYear(toDateTime({col}))",
}

_show_functions_column = "name"
column_type_mappings = (
(
re.compile(r".*Enum.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Array.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*UUID.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Bool.*", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
(
re.compile(r".*String.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Int\d+.*", re.IGNORECASE),
types.INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*Decimal.*", re.IGNORECASE),
types.DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*U?DateTime.*", re.IGNORECASE),
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r".*U?Date.*", re.IGNORECASE),
types.Date(),
GenericDataType.TEMPORAL,
),
)

@classmethod
def epoch_to_dttm(cls) -> str:
return "{col}"

@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None


class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
"""Engine spec for clickhouse_sqlalchemy connector"""

engine = "clickhouse"
engine_name = "ClickHouse"

_show_functions_column = "name"
supports_file_upload = False

@classmethod
Expand All @@ -73,21 +157,9 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
return exception
return new_exception(str(exception))

@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None

@classmethod
@cache_manager.cache.memoize()
def get_function_names(cls, database: "Database") -> List[str]:
def get_function_names(cls, database: Database) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
Expand Down Expand Up @@ -123,3 +195,188 @@ def get_function_names(cls, database: "Database") -> List[str]:

# otherwise, return no function names to prevent errors
return []


class ClickHouseParametersSchema(Schema):
username = fields.String(allow_none=True, description=__("Username"))
password = fields.String(allow_none=True, description=__("Password"))
host = fields.String(required=True, description=__("Hostname or IP address"))
port = fields.Integer(
allow_none=True,
description=__("Database port"),
validate=Range(min=0, max=65535),
)
database = fields.String(allow_none=True, description=__("Database name"))
encryption = fields.Boolean(
default=True, description=__("Use an encrypted connection to the database")
)
query = fields.Dict(
keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters")
)


try:
from clickhouse_connect.cc_superset.datatypes import configure_types
from clickhouse_connect.common import set_setting

configure_types()
set_setting(
"product_name",
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
)
except ImportError: # ClickHouse Connect not installed, do nothing
pass


class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
"""Engine spec for clickhouse-connect connector"""

engine = "clickhousedb"
engine_name = "ClickHouse Connect"

default_driver = "connect"
_function_names: List[str] = []

sqlalchemy_uri_placeholder = (
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
)
parameters_schema = ClickHouseParametersSchema()
encryption_parameters = {"secure": "true"}

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
return {}

@classmethod
def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
if new_exception == SupersetDBAPIDatabaseError:
return SupersetDBAPIDatabaseError("Connection failed")
if not new_exception:
return exception
return new_exception(str(exception))

@classmethod
def get_function_names(cls, database: Database) -> List[str]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver.exceptions import ClickHouseError

if cls._function_names:
return cls._function_names
try:
names = database.get_df(
"SELECT name FROM system.functions UNION ALL "
+ "SELECT name FROM system.table_functions LIMIT 10000"
)["name"].tolist()
cls._function_names = names
return names
except ClickHouseError:
logger.exception("Error retrieving system.functions")
return []

@classmethod
def get_datatype(cls, type_code: str) -> str:
# keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL
return type_code

@classmethod
def build_sqlalchemy_uri( # pylint: disable=unused-argument
cls,
parameters: BasicParametersType,
encrypted_extra: Optional[Dict[str, str]] = None,
) -> str:
url_params = parameters.copy()
if url_params.get("encryption"):
query = parameters.get("query", {}).copy()
query.update(cls.encryption_parameters)
url_params["query"] = query
if not url_params.get("database"):
url_params["database"] = "__default__"
url_params.pop("encryption", None)
return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params))

@classmethod
def get_parameters_from_uri(
cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = url.query
if "secure" in query:
encryption = url.query.get("secure") == "true"
query.pop("secure")
else:
encryption = False
return BasicParametersType(
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database="" if url.database == "__default__" else cast(str, url.database),
query=dict(query),
encryption=encryption,
)

@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver import default_port

parameters = properties.get("parameters", {})
host = parameters.get("host", None)
if not host:
return [
SupersetError(
"Hostname is required",
SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
ErrorLevel.WARNING,
{"missing": ["host"]},
)
]
if not is_hostname_valid(host):
return [
SupersetError(
"The hostname provided can't be resolved.",
SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
ErrorLevel.ERROR,
{"invalid": ["host"]},
)
]
port = parameters.get("port")
if port is None:
port = default_port("http", parameters.get("encryption", False))
try:
port = int(port)
except (ValueError, TypeError):
port = -1
if port <= 0 or port >= 65535:
return [
SupersetError(
"Port must be a valid integer between 0 and 65535 (inclusive).",
SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
if not is_port_open(host, port):
return [
SupersetError(
"The port is closed.",
SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
return []

@staticmethod
def _mutate_label(label: str) -> str:
"""
Suffix with the first six characters from the md5 of the label to avoid
collisions with original column names

:param label: Expected expression label
:return: Conditionally mutated label
"""
return f"{label}_{md5_sha_from_str(label)[:6]}"
Loading