Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SSL certificate validation for Druid #9396

Merged
merged 5 commits into from Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/installation.rst
Expand Up @@ -729,6 +729,12 @@ The native Druid connector (behind the ``DRUID_IS_ACTIVE`` feature flag)
is slowly getting deprecated in favor of the SQLAlchemy/DBAPI connector made
available in the ``pydruid`` library.

To use a custom SSL certificate to validate HTTPS requests, the certificate
contents can be entered in the ``Root Certificate`` field in the Database
dialog. When using a custom certificate, ``pydruid`` will automatically use
``https`` scheme. To disable SSL verification add the following to extras:
``engine_params": {"connect_args": {"scheme": "https", "ssl_verify_cert": false}}``

Dremio
------

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

Expand Down
5 changes: 5 additions & 0 deletions superset/config.py
Expand Up @@ -797,6 +797,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# Typically these should not be allowed.
PREVENT_UNSAFE_DB_CONNECTIONS = True

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
SSL_CERT_PATH: Optional[str] = None

# SIP-15 should be enabled for all new Superset deployments which ensures that the time
# range endpoints adhere to [start, end). For existing deployments admins should provide
# a dedicated period of time to allow chart producers to update their charts before
Expand Down
22 changes: 22 additions & 0 deletions superset/db_engine_specs/base.py
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=unused-argument
import hashlib
import json
import logging
import os
import re
from contextlib import closing
Expand Down Expand Up @@ -59,6 +61,8 @@
)
from superset.models.core import Database # pylint: disable=unused-import

logger = logging.getLogger()


class TimeGrain(NamedTuple): # pylint: disable=too-few-public-methods
name: str # TODO: redundant field, remove
Expand Down Expand Up @@ -959,3 +963,21 @@ def mutate_db_for_connection_test(database: "Database") -> None:
:param database: instance to be mutated
"""
return None

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.

:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's nice to add a :raises: on the docstring when a method can raise an exception

extra: Dict[str, Any] = {}
if database.extra:
try:
extra = json.loads(database.extra)
except json.JSONDecodeError as e:
logger.error(e)
raise e
return extra
32 changes: 31 additions & 1 deletion superset/db_engine_specs/druid.py
Expand Up @@ -14,14 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING
import json
import logging
from typing import Any, Dict, TYPE_CHECKING

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
TableColumn,
)
from superset.models.core import Database # pylint: disable=unused-import

logger = logging.getLogger()


class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
Expand All @@ -47,3 +53,27 @@ class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.

:param database: database instance from which to extract extras
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's nice to add a :raises: on the docstring when a method can raise an exception

:raises CertificateException: If certificate is not valid/unparseable
"""
try:
extra = json.loads(database.extra or "{}")
except json.JSONDecodeError as e:
logger.error(e)
raise e

if database.server_cert:
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
connect_args["scheme"] = "https"
path = utils.create_ssl_cert_file(database.server_cert)
connect_args["ssl_verify_cert"] = path
villebro marked this conversation as resolved.
Show resolved Hide resolved
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra
4 changes: 4 additions & 0 deletions superset/exceptions.py
Expand Up @@ -60,5 +60,9 @@ class SpatialException(SupersetException):
pass


class CertificateException(SupersetException):
pass


class DatabaseNotFound(SupersetException):
status = 400
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add certificate to dbs

Revision ID: b5998378c225
Revises: 72428d1ea401
Create Date: 2020-03-25 10:49:10.883065

"""

# revision identifiers, used by Alembic.
revision = "b5998378c225"
down_revision = "72428d1ea401"

from typing import Dict

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy_utils import EncryptedType


def upgrade():
kwargs: Dict[str, str] = {}
bind = op.get_bind()
op.add_column(
"dbs",
sa.Column("server_cert", EncryptedType(sa.Text()), nullable=True, **kwargs),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense, while we're at it, to create a client_cert field also?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did but took it out, as pydruid doesn't support it yet. However, I opened a PR for that, and will add the client cert field once that PR is merged.

)


def downgrade():
op.drop_column("dbs", "server_cert")
11 changes: 3 additions & 8 deletions superset/models/core.py
Expand Up @@ -139,6 +139,7 @@ class Database(
encrypted_extra = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True)
perm = Column(String(1000))
impersonate_user = Column(Boolean, default=False)
server_cert = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True)
export_fields = [
"database_name",
"sqlalchemy_uri",
Expand Down Expand Up @@ -309,6 +310,7 @@ def get_sqla_engine(
)
if configuration:
connect_args["configuration"] = configuration
if connect_args:
params["connect_args"] = connect_args

params.update(self.get_encrypted_extra())
Expand Down Expand Up @@ -555,14 +557,7 @@ def grains(self) -> Tuple[TimeGrain, ...]:
return self.db_engine_spec.get_time_grains()

def get_extra(self) -> Dict[str, Any]:
extra: Dict[str, Any] = {}
if self.extra:
try:
extra = json.loads(self.extra)
except json.JSONDecodeError as e:
logger.error(e)
raise e
return extra
return self.db_engine_spec.get_extra_params(self)

def get_encrypted_extra(self):
encrypted_extra = {}
Expand Down
1 change: 1 addition & 0 deletions superset/templates/superset/models/database/add.html
Expand Up @@ -24,4 +24,5 @@
{{ macros.testconn() }}
{{ macros.expand_extra_textarea() }}
{{ macros.expand_encrypted_extra_textarea() }}
{{ macros.expand_server_cert_textarea() }}
{% endblock %}
1 change: 1 addition & 0 deletions superset/templates/superset/models/database/edit.html
Expand Up @@ -24,4 +24,5 @@
{{ macros.testconn() }}
{{ macros.expand_extra_textarea() }}
{{ macros.expand_encrypted_extra_textarea() }}
{{ macros.expand_server_cert_textarea() }}
{% endblock %}
7 changes: 7 additions & 0 deletions superset/templates/superset/models/database/macros.html
Expand Up @@ -43,6 +43,7 @@
impersonate_user: $('#impersonate_user').is(':checked'),
extras: extra ? JSON.parse(extra) : {},
encrypted_extra: encryptedExtra ? JSON.parse(encryptedExtra) : {},
server_cert: $("#server_cert").val(),
})
} catch(parse_error){
alert("Malformed JSON in the extras field: " + parse_error);
Expand Down Expand Up @@ -81,3 +82,9 @@
$('#encrypted_extra').attr('rows', '5');
</script>
{% endmacro %}

{% macro expand_server_cert_textarea() %}
<script>
$('#server_cert').attr('rows', '5');
</script>
{% endmacro %}
52 changes: 50 additions & 2 deletions superset/utils/core.py
Expand Up @@ -19,12 +19,13 @@
import decimal
import errno
import functools
import hashlib
import json
import logging
import os
import re
import signal
import smtplib
import tempfile
import traceback
import uuid
import zlib
Expand All @@ -45,6 +46,9 @@
import pandas as pd
import parsedatetime
import sqlalchemy as sa
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import current_app, flash, Flask, g, Markup, render_template
Expand All @@ -56,7 +60,11 @@
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator

from superset.exceptions import SupersetException, SupersetTimeoutException
from superset.exceptions import (
CertificateException,
SupersetException,
SupersetTimeoutException,
)
from superset.utils.dates import datetime_to_epoch, EPOCH

try:
Expand Down Expand Up @@ -1163,6 +1171,46 @@ def get_username() -> Optional[str]:
return None


def parse_ssl_cert(certificate: str) -> _Certificate:
"""
Parses the contents of a certificate and returns a valid certificate object
if valid.

:param certificate: Contents of certificate file
:return: Valid certificate instance
:raises CertificateException: If certificate is not valid/unparseable
"""
try:
return x509.load_pem_x509_certificate(
certificate.encode("utf-8"), default_backend()
)
except ValueError as e:
raise CertificateException("Invalid certificate")


def create_ssl_cert_file(certificate: str) -> str:
"""
This creates a certificate file that can be used to validate HTTPS
sessions. A certificate is only written to disk once; on subsequent calls,
only the path of the existing certificate is returned.

:param certificate: The contents of the certificate
:return: The path to the certificate file
:raises CertificateException: If certificate is not valid/unparseable
"""
filename = f"{hashlib.md5(certificate.encode('utf-8')).hexdigest()}.crt"
cert_dir = current_app.config["SSL_CERT_PATH"]
path = cert_dir if cert_dir else tempfile.gettempdir()
path = os.path.join(path, filename)
if not os.path.exists(path):
# Validate certificate prior to persisting to temporary directory
parse_ssl_cert(certificate)
cert_file = open(path, "w")
cert_file.write(certificate)
cert_file.close()
return path


def MediumText() -> Variant:
return Text().with_variant(MEDIUMTEXT(), "mysql")

Expand Down
13 changes: 13 additions & 0 deletions superset/views/core.py
Expand Up @@ -67,6 +67,7 @@
from superset.connectors.sqla.models import AnnotationDatasource
from superset.constants import RouteMethod
from superset.exceptions import (
CertificateException,
DatabaseNotFound,
SupersetException,
SupersetSecurityException,
Expand Down Expand Up @@ -1374,6 +1375,7 @@ def testconn(self):
# this is the database instance that will be tested
database = models.Database(
# extras is sent as json, but required to be a string in the Database model
server_cert=request.json.get("server_cert"),
extra=json.dumps(request.json.get("extras", {})),
impersonate_user=request.json.get("impersonate_user"),
encrypted_extra=json.dumps(request.json.get("encrypted_extra", {})),
Expand All @@ -1387,6 +1389,17 @@ def testconn(self):
with closing(engine.connect()) as conn:
conn.scalar(select([1]))
return json_success('"OK"')
except CertificateException as e:
logger.info("Invalid certificate %s", e)
return json_error_response(
_(
"Invalid certificate. "
"Please make sure the certificate begins with\n"
"-----BEGIN CERTIFICATE-----\n"
"and ends with \n"
"-----END CERTIFICATE-----"
)
)
except NoSuchModuleError as e:
logger.info("Invalid driver %s", e)
driver_name = make_url(uri).drivername
Expand Down