diff --git a/setup.cfg b/setup.cfg index 9a108f76a480..20bf9082a28b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,babel,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index e2dcc581d7fb..ac5816248838 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -36,7 +36,7 @@ from superset.extensions import event_logger from superset.models.core import Database -BYPASS_VALIDATION_ENGINES = {"bigquery"} +BYPASS_VALIDATION_ENGINES = {"bigquery", "snowflake"} class ValidateDatabaseParametersCommand(BaseCommand): diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 6dd85706562b..fa31c2394ef3 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -17,14 +17,18 @@ import json import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, TypedDict from urllib import parse +from apispec import APISpec +from apispec.ext.marshmallow import MarshmallowPlugin +from babel.core import default_locale from flask_babel import gettext as __ +from marshmallow import fields, Schema from sqlalchemy.engine.url import URL from superset.db_engine_specs.postgres import PostgresBaseEngineSpec -from superset.errors import SupersetErrorType +from superset.errors import SupersetError, SupersetErrorType from superset.models.sql_lab import Query from superset.utils import core as utils @@ -42,12 +46,34 @@ ) +class SnowflakeParametersSchema(Schema): + username = fields.Str(required=True) + password = fields.Str(required=True) + account = fields.Str(required=True) + database = fields.Str(required=True) + role = fields.Str(required=True) + warehouse = fields.Str(required=True) + + +class SnowflakeParametersType(TypedDict): + username: str + password: str + account: str + database: str + role: str + warehouse: str + + class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = "snowflake" engine_name = "Snowflake" force_column_alias_quotes = True max_column_name_length = 256 + parameters_schema = SnowflakeParametersSchema() + default_driver = "snowflake" + sqlalchemy_uri_placeholder = "snowflake://" + _time_grain_expressions = { None: "{col}", "PT1S": "DATE_TRUNC('SECOND', {col})", @@ -160,3 +186,59 @@ def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: return False return True + + @classmethod + def build_sqlalchemy_uri( + cls, + parameters: SnowflakeParametersType, + encrypted_extra: Optional[Dict[str, Any]] = None, + ) -> str: + query = parameters.get("query", {}) + query_params = urllib.parse.urlencode(query) + + if not encrypted_extra: + raise ValidationError("Missing service credentials") + + project_id = encrypted_extra.get("credentials_info", {}).get("project_id") + + if project_id: + return f"{cls.default_driver}://{project_id}/?{query_params}" + + raise ValidationError("Invalid service credentials") + + @classmethod + def get_parameters_from_uri( + cls, uri: str, encrypted_extra: Optional[Dict[str, str]] = None + ) -> Any: + value = make_url(uri) + + # Building parameters from encrypted_extra and uri + if encrypted_extra: + return {**encrypted_extra, "query": value.query} + + raise ValidationError("Invalid service credentials") + + @classmethod + def validate_parameters( + cls, parameters: SnowflakeParametersType # pylint: disable=unused-argument + ) -> List[SupersetError]: + return [] + + @classmethod + def parameters_json_schema(cls) -> Any: + """ + Return configuration parameters as OpenAPI. + """ + if not cls.parameters_schema: + return None + + ma_plugin = MarshmallowPlugin() + spec = APISpec( + title="Database Parameters", + version="1.0.0", + openapi_version="3.0.0", + plugins=[ma_plugin], + ) + + spec.components.schema(cls.__name__, schema=cls.parameters_schema) + return spec.to_dict()["components"]["schemas"][cls.__name__]