From 830a28355cb02549b05e4ee2c2284aa3053f3a62 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 19 Oct 2022 19:27:55 -0400 Subject: [PATCH 01/75] save --- superset/databases/models.py | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 superset/databases/models.py diff --git a/superset/databases/models.py b/superset/databases/models.py new file mode 100644 index 000000000000..2595ca452b40 --- /dev/null +++ b/superset/databases/models.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List + +import sqlalchemy as sa +from flask_appbuilder import Model +from sqlalchemy.orm import backref, relationship + +from superset.models.core import Database +from superset.models.helpers import ( + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, +) + + +class SSHTunnelConfiguration( + Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin +): + """ + A table/view in a database. 
+ """ + + __tablename__ = "sl_datasets" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + backref=backref("datasets", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) From 2c1e7363c4c94b5a0a7a3094e911190a86eac9b3 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 20 Oct 2022 12:01:38 -0400 Subject: [PATCH 02/75] create migration --- superset/databases/models.py | 30 +++++- ...c8595_create_ssh_tunnel_credentials_tbl.py | 98 +++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py diff --git a/superset/databases/models.py b/superset/databases/models.py index 2595ca452b40..f1dfe5f1aa10 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -20,6 +20,7 @@ from flask_appbuilder import Model from sqlalchemy.orm import backref, relationship +from superset import app from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, @@ -27,20 +28,41 @@ ImportExportMixin, ) +app_config = app.config -class SSHTunnelConfiguration( + +class SSHTunnelCredentials( Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin ): """ - A table/view in a database. + A ssh tunnel configuration in a database. 
""" - __tablename__ = "sl_datasets" + __tablename__ = "ssh_tunnel_credentials" id = sa.Column(sa.Integer, primary_key=True) database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) database: Database = relationship( "Database", - backref=backref("datasets", cascade="all, delete-orphan"), + backref=backref("ssh_tunnel_credentials", cascade="all, delete-orphan"), foreign_keys=[database_id], ) + + server_address = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) + server_port = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) + + # basic authentication + username = sa.Column( + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + password = sa.Column( + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + + # password protected pkey authentication + pkey = sa.Column( + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + private_key = sa.Column( + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py new file mode 100644 index 000000000000..fa448a688872 --- /dev/null +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""create_ssh_tunnel_credentials_tbl + +Revision ID: f3c2d8ec8595 +Revises: deb4c9d4a4ef +Create Date: 2022-10-20 10:48:08.722861 + +""" + +# revision identifiers, used by Alembic. +revision = "f3c2d8ec8595" +down_revision = "deb4c9d4a4ef" + +from uuid import uuid4 + +import sqlalchemy as sa +from alembic import op +from sqlalchemy_utils import UUIDType + +from superset import app + +app_config = app.config + + +def upgrade(): + op.create_table( + "ssh_tunnel_credential", + # AuditMixinNullable + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + # ExtraJSONMixin + sa.Column("extra_json", sa.Text(), nullable=True), + # ImportExportMixin + sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), + # specific to model + sa.Column("server_port", sa.EncryptedType(sa.String, app_config["SECRET_KEY"])), + sa.Column( + "server_address", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]) + ), + sa.Column( + "username", + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + nullable=True, + ), + sa.Column( + "password", + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + nullable=True, + ), + sa.Column( + "pkey", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ), + sa.Column( + "private_key_password", + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + nullable=True, + ), + ) + + op.create_table( + 
"database_ssh_tunnel_credential", + sa.Column( + "ssh_tunnel_credential_id", + sa.INTEGER(), + autoincrement=False, + nullable=False, + ), + sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["database_id"], ["dbs.id"], name="database_ssh_tunnel_credential_ibfk_1" + ), + sa.ForeignKeyConstraint( + ["ssh_tunnel_credentials_id"], + ["ssh_tunnel_credentials.id"], + name="database_ssh_tunnel_credential_ibfk_2", + ), + ) + + +def downgrade(): + op.drop_table("ssh_tunnel_credentials") From f78df837462871971790852432b44b29b5256793 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 21 Oct 2022 09:38:08 -0700 Subject: [PATCH 03/75] created schema and rename --- superset/databases/models.py | 4 ++-- superset/databases/schemas.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/superset/databases/models.py b/superset/databases/models.py index f1dfe5f1aa10..f13c04350fc1 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -60,9 +60,9 @@ class SSHTunnelCredentials( ) # password protected pkey authentication - pkey = sa.Column( + private_key = sa.Column( sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) - private_key = sa.Column( + private_key_password = sa.Column( sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index dafd1ba7fc71..bb1153b76ddc 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -676,6 +676,26 @@ class EncryptedDict(EncryptedField, fields.Dict): pass +class DatabaseSSHTunnelCredentials(Schema): + id = fields.Integer() + database_id = fields.Integer() + + server_address = fields.String() + server_port = fields.Integer() + username = fields.String() + + # Basic Authentication + password = fields.String(required=False) + + # password protected private key authentication + private_key = 
fields.String(required=False) + private_key_password = fields.String(required=False) + + # remote binding port + bind_host = fields.String() + bind_port = fields.Integer() + + def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore ret = {} if isinstance(field, EncryptedField): From d482df46c166338c0c581bb971d50506fbbef7ba Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 21 Oct 2022 11:36:27 -0700 Subject: [PATCH 04/75] linting --- superset/databases/models.py | 4 +--- ..._f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 11 ++++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/superset/databases/models.py b/superset/databases/models.py index f13c04350fc1..4da6237bbd2a 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -50,11 +50,9 @@ class SSHTunnelCredentials( server_address = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) server_port = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) + username = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) # basic authentication - username = sa.Column( - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True - ) password = sa.Column( sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index fa448a688872..3670723a0781 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -39,7 +39,7 @@ def upgrade(): op.create_table( - "ssh_tunnel_credential", + "ssh_tunnel_credentials", # AuditMixinNullable sa.Column("created_on", sa.DateTime(), nullable=True), sa.Column("changed_on", sa.DateTime(), nullable=True), 
@@ -49,15 +49,14 @@ def upgrade(): sa.Column("extra_json", sa.Text(), nullable=True), # ImportExportMixin sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), - # specific to model - sa.Column("server_port", sa.EncryptedType(sa.String, app_config["SECRET_KEY"])), + # Specific to model sa.Column( "server_address", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]) ), + sa.Column("server_port", sa.EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column( "username", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), - nullable=True, ), sa.Column( "password", @@ -65,7 +64,9 @@ def upgrade(): nullable=True, ), sa.Column( - "pkey", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + "private_key", + sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + nullable=True, ), sa.Column( "private_key_password", From 9edb581f103d3e29b40d1c814cbe935a3d36f486 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 21 Oct 2022 13:15:49 -0700 Subject: [PATCH 05/75] fix encrpytions --- superset/databases/models.py | 13 +++++++------ ...d8ec8595_create_ssh_tunnel_credentials_tbl.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/superset/databases/models.py b/superset/databases/models.py index 4da6237bbd2a..bdc6619ca663 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -19,6 +19,7 @@ import sqlalchemy as sa from flask_appbuilder import Model from sqlalchemy.orm import backref, relationship +from sqlalchemy_utils import EncryptedType from superset import app from superset.models.core import Database @@ -48,19 +49,19 @@ class SSHTunnelCredentials( foreign_keys=[database_id], ) - server_address = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) - server_port = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) - username = sa.Column(sa.EncryptedType(sa.String, app_config["SECRET_KEY"])) + server_address = sa.Column(EncryptedType(sa.String, 
app_config["SECRET_KEY"])) + server_port = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) + username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) # basic authentication password = sa.Column( - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) # password protected pkey authentication private_key = sa.Column( - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) private_key_password = sa.Column( - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 3670723a0781..9da57b25ab97 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -30,7 +30,7 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy_utils import UUIDType +from sqlalchemy_utils import EncryptedType, UUIDType from superset import app @@ -50,27 +50,25 @@ def upgrade(): # ImportExportMixin sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), # Specific to model - sa.Column( - "server_address", sa.EncryptedType(sa.String, app_config["SECRET_KEY"]) - ), - sa.Column("server_port", sa.EncryptedType(sa.String, app_config["SECRET_KEY"])), + sa.Column("server_address", EncryptedType(sa.String, app_config["SECRET_KEY"])), + sa.Column("server_port", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column( "username", - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + EncryptedType(sa.String, 
app_config["SECRET_KEY"]), ), sa.Column( "password", - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True, ), sa.Column( "private_key", - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True, ), sa.Column( "private_key_password", - sa.EncryptedType(sa.String, app_config["SECRET_KEY"]), + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True, ), ) From da27d8f8d1452f668308154293791ed60402fece Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 24 Oct 2022 12:55:48 -0400 Subject: [PATCH 06/75] remove map tabl --- ...c8595_create_ssh_tunnel_credentials_tbl.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 9da57b25ab97..27e2e7118243 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -73,25 +73,6 @@ def upgrade(): ), ) - op.create_table( - "database_ssh_tunnel_credential", - sa.Column( - "ssh_tunnel_credential_id", - sa.INTEGER(), - autoincrement=False, - nullable=False, - ), - sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ["database_id"], ["dbs.id"], name="database_ssh_tunnel_credential_ibfk_1" - ), - sa.ForeignKeyConstraint( - ["ssh_tunnel_credentials_id"], - ["ssh_tunnel_credentials.id"], - name="database_ssh_tunnel_credential_ibfk_2", - ), - ) - def downgrade(): op.drop_table("ssh_tunnel_credentials") From 773a6c86f64445a6d67c68cf7e30c245a45b2cd4 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 25 Oct 2022 11:44:48 -0400 Subject: [PATCH 07/75] fix linting --- 
superset/databases/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/databases/models.py b/superset/databases/models.py index bdc6619ca663..721e77a2c0f8 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List import sqlalchemy as sa from flask_appbuilder import Model From 2f2dda2a12d970b565dcb3590680390d9e51442b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 25 Oct 2022 11:48:54 -0400 Subject: [PATCH 08/75] add constraint --- ...20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 27e2e7118243..f3801ad7851a 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -49,7 +49,8 @@ def upgrade(): sa.Column("extra_json", sa.Text(), nullable=True), # ImportExportMixin sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), - # Specific to model + # SSHTunnelCredentials + sa.Column("database_id", sa.INTEGER(), nullable=True), sa.Column("server_address", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column("server_port", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column( @@ -71,6 +72,7 @@ def upgrade(): EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True, ), + sa.PrimaryKeyConstraint("id"), ) From fd0d7f2a52a39a9486e26483cfa1d8074d0afdb1 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 25 Oct 2022 11:55:05 -0400 Subject: [PATCH 09/75] add fk to migration --- 
...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index f3801ad7851a..5bf50f870d1d 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -50,7 +50,7 @@ def upgrade(): # ImportExportMixin sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), # SSHTunnelCredentials - sa.Column("database_id", sa.INTEGER(), nullable=True), + sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), nullable=True), sa.Column("server_address", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column("server_port", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column( From 158da8d2008fc26eb191a600d538d6796caffc3a Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 26 Oct 2022 14:03:32 -0400 Subject: [PATCH 10/75] init --- superset/connectors/sqla/models.py | 12 +- superset/connectors/sqla/utils.py | 39 +++--- .../databases/commands/test_connection.py | 56 ++++---- superset/databases/commands/validate.py | 46 +++---- .../datasets/commands/importers/v1/utils.py | 3 +- superset/db_engine_specs/base.py | 5 +- superset/sql_lab.py | 126 +++++++++--------- 7 files changed, 154 insertions(+), 133 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 4855dd1af3cc..98aac4906fb7 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -958,13 +958,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: if self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) - engine = 
self.database.get_sqla_engine() - sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) - sql = self._apply_cte(sql, cte) - sql = self.mutate_query_from_config(sql) + with self.database.get_sqla_engine_with_context() as engine: + sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) + sql = self._apply_cte(sql, cte) + sql = self.mutate_query_from_config(sql) - df = pd.read_sql_query(sql=sql, con=engine) - return df[column_name].to_list() + df = pd.read_sql_query(sql=sql, con=engine) + return df[column_name].to_list() def mutate_query_from_config(self, sql: str) -> str: """Apply config's SQL_QUERY_MUTATOR diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 8151bfd44b03..05cf8cea1324 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: ) db_engine_spec = dataset.database.db_engine_spec - engine = dataset.database.get_sqla_engine(schema=dataset.schema) sql = dataset.get_template_processor().process_template( dataset.sql, **dataset.template_params_dict ) @@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: # TODO(villebro): refactor to use same code that's used by # sql_lab.py:execute_sql_statements try: - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - query = dataset.database.apply_limit_to_sql(statements[0], limit=1) - db_engine_spec.execute(cursor, query) - result = db_engine_spec.fetch_data(cursor, limit=1) - result_set = SupersetResultSet(result, cursor.description, db_engine_spec) - cols = result_set.columns + with dataset.database.get_sqla_engine_with_context( + schema=dataset.schema + ) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + query = dataset.database.apply_limit_to_sql(statements[0], limit=1) + db_engine_spec.execute(cursor, query) + result = 
db_engine_spec.fetch_data(cursor, limit=1) + result_set = SupersetResultSet( + result, cursor.description, db_engine_spec + ) + cols = result_set.columns except Exception as ex: raise SupersetGenericDBErrorException(message=str(ex)) from ex return cols @@ -155,14 +159,17 @@ def get_columns_description( ) -> List[ResultSetColumnType]: db_engine_spec = database.db_engine_spec try: - with closing(database.get_sqla_engine().raw_connection()) as conn: - cursor = conn.cursor() - query = database.apply_limit_to_sql(query, limit=1) - cursor.execute(query) - db_engine_spec.execute(cursor, query) - result = db_engine_spec.fetch_data(cursor, limit=1) - result_set = SupersetResultSet(result, cursor.description, db_engine_spec) - return result_set.columns + with database.get_sqla_engine_with_context() as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + query = database.apply_limit_to_sql(query, limit=1) + cursor.execute(query) + db_engine_spec.execute(cursor, query) + result = db_engine_spec.fetch_data(cursor, limit=1) + result_set = SupersetResultSet( + result, cursor.description, db_engine_spec + ) + return result_set.columns except Exception as ex: raise SupersetGenericDBErrorException(message=str(ex)) from ex diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index d7f7d90e4922..2865174ff805 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -86,7 +86,6 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - engine = database.get_sqla_engine() event_logger.log_with_context( action="test_connection_attempt", engine=database.db_engine_spec.__name__, @@ -96,31 +95,36 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - try: - alive = func_timeout( - 
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()), - ping, - args=(engine,), - ) - except (sqlite3.ProgrammingError, RuntimeError): - # SQLite can't run on a separate thread, so ``func_timeout`` fails - # RuntimeError catches the equivalent error from duckdb. - alive = engine.dialect.do_ping(engine) - except FunctionTimedOut as ex: - raise SupersetTimeoutException( - error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, - message=( - "Please check your connection details and database settings, " - "and ensure that your database is accepting connections, " - "then try connecting again." - ), - level=ErrorLevel.ERROR, - extra={"sqlalchemy_uri": database.sqlalchemy_uri}, - ) from ex - except Exception: # pylint: disable=broad-except - alive = False - if not alive: - raise DBAPIError(None, None, None) + with database.get_sqla_engine_with_context() as engine: + try: + alive = func_timeout( + int( + app.config[ + "TEST_DATABASE_CONNECTION_TIMEOUT" + ].total_seconds() + ), + ping, + args=(engine,), + ) + except (sqlite3.ProgrammingError, RuntimeError): + # SQLite can't run on a separate thread, so ``func_timeout`` fails + # RuntimeError catches the equivalent error from duckdb. + alive = engine.dialect.do_ping(engine) + except FunctionTimedOut as ex: + raise SupersetTimeoutException( + error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, + message=( + "Please check your connection details and database settings, " + "and ensure that your database is accepting connections, " + "then try connecting again." 
+ ), + level=ErrorLevel.ERROR, + extra={"sqlalchemy_uri": database.sqlalchemy_uri}, + ) from ex + except Exception: # pylint: disable=broad-except + alive = False + if not alive: + raise DBAPIError(None, None, None) # Log succesful connection test with engine event_logger.log_with_context( diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index a8956257fa28..a92fb79f83ed 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -101,30 +101,30 @@ def run(self) -> None: database.set_sqlalchemy_uri(sqlalchemy_uri) database.db_engine_spec.mutate_db_for_connection_test(database) - engine = database.get_sqla_engine() - try: - with closing(engine.raw_connection()) as conn: - alive = engine.dialect.do_ping(conn) - except Exception as ex: - url = make_url_safe(sqlalchemy_uri) - context = { - "hostname": url.host, - "password": url.password, - "port": url.port, - "username": url.username, - "database": url.database, - } - errors = database.db_engine_spec.extract_errors(ex, context) - raise DatabaseTestConnectionFailedError(errors) from ex + with database.get_sqla_engine_with_context() as engine: + try: + with closing(engine.raw_connection()) as conn: + alive = engine.dialect.do_ping(conn) + except Exception as ex: + url = make_url_safe(sqlalchemy_uri) + context = { + "hostname": url.host, + "password": url.password, + "port": url.port, + "username": url.username, + "database": url.database, + } + errors = database.db_engine_spec.extract_errors(ex, context) + raise DatabaseTestConnectionFailedError(errors) from ex - if not alive: - raise DatabaseOfflineError( - SupersetError( - message=__("Database is offline."), - error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - level=ErrorLevel.ERROR, - ), - ) + if not alive: + raise DatabaseOfflineError( + SupersetError( + message=__("Database is offline."), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), 
+ ) def validate(self) -> None: database_id = self._properties.get("id") diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index ba2b7df26174..7d3998b3bb54 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -168,7 +168,8 @@ def load_data( connection = session.connection() else: logger.warning("Loading data outside the import transaction") - connection = database.get_sqla_engine() + with database.get_sqla_engine_with_context() as engine: + connection = engine df.to_sql( dataset.table_name, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index dabed0c7aeee..9dd7594dc720 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -472,7 +472,10 @@ def get_engine( schema: Optional[str] = None, source: Optional[utils.QuerySource] = None, ) -> Engine: - return database.get_sqla_engine(schema=schema, source=source) + with database.get_sqla_engine_with_context( + schema=schema, source=source + ) as engine: + return engine @classmethod def get_timestamp_expr( diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 96afc7f51ed9..6d9903c8f000 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -463,61 +463,66 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca ) ) - engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB) - # Sharing a single connection and cursor across the - # execution of all statements (if many) - with closing(engine.raw_connection()) as conn: - # closing the connection closes the cursor as well - cursor = conn.cursor() - cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) - if cancel_query_id is not None: - query.set_extra_json_key(cancel_query_key, cancel_query_id) - session.commit() - statement_count = len(statements) - for i, statement in enumerate(statements): - # Check if stopped 
- session.refresh(query) - if query.status == QueryStatus.STOPPED: - payload.update({"status": query.status}) - return payload - - # For CTAS we create the table only on the last statement - apply_ctas = query.select_as_cta and ( - query.ctas_method == CtasMethod.VIEW - or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1) - ) - - # Run statement - msg = f"Running statement {i+1} out of {statement_count}" - logger.info("Query %s: %s", str(query_id), msg) - query.set_extra_json_key("progress", msg) - session.commit() - try: - result_set = execute_sql_statement( - statement, - query, - session, - cursor, - log_params, - apply_ctas, - ) - except SqlLabQueryStoppedException: - payload.update({"status": QueryStatus.STOPPED}) - return payload - except Exception as ex: # pylint: disable=broad-except - msg = str(ex) - prefix_message = ( - f"[Statement {i+1} out of {statement_count}]" - if statement_count > 1 - else "" + with database.get_sqla_engine_with_context( + query.schema, source=QuerySource.SQL_LAB + ) as engine: + # Sharing a single connection and cursor across the + # execution of all statements (if many) + with closing(engine.raw_connection()) as conn: + # closing the connection closes the cursor as well + cursor = conn.cursor() + cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) + if cancel_query_id is not None: + query.set_extra_json_key(cancel_query_key, cancel_query_id) + session.commit() + statement_count = len(statements) + for i, statement in enumerate(statements): + # Check if stopped + session.refresh(query) + if query.status == QueryStatus.STOPPED: + payload.update({"status": query.status}) + return payload + + # For CTAS we create the table only on the last statement + apply_ctas = query.select_as_cta and ( + query.ctas_method == CtasMethod.VIEW + or ( + query.ctas_method == CtasMethod.TABLE + and i == len(statements) - 1 + ) ) - payload = handle_query_error( - ex, query, session, payload, prefix_message - ) - return 
payload - # Commit the connection so CTA queries will create the table. - conn.commit() + # Run statement + msg = f"Running statement {i+1} out of {statement_count}" + logger.info("Query %s: %s", str(query_id), msg) + query.set_extra_json_key("progress", msg) + session.commit() + try: + result_set = execute_sql_statement( + statement, + query, + session, + cursor, + log_params, + apply_ctas, + ) + except SqlLabQueryStoppedException: + payload.update({"status": QueryStatus.STOPPED}) + return payload + except Exception as ex: # pylint: disable=broad-except + msg = str(ex) + prefix_message = ( + f"[Statement {i+1} out of {statement_count}]" + if statement_count > 1 + else "" + ) + payload = handle_query_error( + ex, query, session, payload, prefix_message + ) + return payload + + # Commit the connection so CTA queries will create the table. + conn.commit() # Success, updating the query entry in database query.rows = result_set.size @@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool: if cancel_query_id is None: return False - engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB) - - with closing(engine.raw_connection()) as conn: - with closing(conn.cursor()) as cursor: - return query.database.db_engine_spec.cancel_query( - cursor, query, cancel_query_id - ) + with query.database.get_sqla_engine_with_context( + query.schema, source=QuerySource.SQL_LAB + ) as engine: + with closing(engine.raw_connection()) as conn: + with closing(conn.cursor()) as cursor: + return query.database.db_engine_spec.cancel_query( + cursor, query, cancel_query_id + ) From face73f23125683931a6629ac220c29a156b351b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 26 Oct 2022 15:35:29 -0400 Subject: [PATCH 11/75] update all the examples --- superset/examples/bart_lines.py | 47 +++++----- superset/examples/birth_names.py | 38 ++++---- superset/examples/country_map.py | 63 +++++++------- superset/examples/energy.py | 35 ++++---- superset/examples/flights.py | 54 
++++++------ superset/examples/long_lat.py | 86 ++++++++++--------- superset/examples/multiformat_time_series.py | 66 +++++++------- superset/examples/paris.py | 42 ++++----- superset/examples/random_time_series.py | 40 ++++----- superset/examples/sf_population_polygons.py | 42 ++++----- .../examples/supported_charts_dashboard.py | 8 +- superset/examples/world_bank.py | 59 ++++++------- 12 files changed, 294 insertions(+), 286 deletions(-) diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 91257058be75..eb11be2eed1b 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -29,31 +29,32 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "bart_lines" database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("bart-lines.json.gz") - df = pd.read_json(url, encoding="latin-1", compression="gzip") - df["path_json"] = df.path.map(json.dumps) - df["polyline"] = df.path.map(polyline.encode) - del df["path"] + if not only_metadata and (not table_exists or force): + url = get_example_url("bart-lines.json.gz") + df = pd.read_json(url, encoding="latin-1", compression="gzip") + df["path_json"] = df.path.map(json.dumps) + df["polyline"] = df.path.map(polyline.encode) + del df["path"] - df.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "color": String(255), - "name": String(255), - "polyline": Text, - "path_json": Text, - }, - index=False, - ) + df.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ 
+ "color": String(255), + "name": String(255), + "polyline": Text, + "path_json": Text, + }, + index=False, + ) print("Creating table {} reference".format(tbl_name)) table = get_table_connector_registry() diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index f8b8a8ecf7ca..3fc86bc849cc 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -76,25 +76,25 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None: pdf.ds = pd.to_datetime(pdf.ds, unit="ms") pdf = pdf.head(100) if sample else pdf - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - - pdf.to_sql( - tbl_name, - database.get_sqla_engine(), - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - # TODO(bkyryliuk): use TIMESTAMP type for presto - "ds": DateTime if database.backend != "presto" else String(255), - "gender": String(16), - "state": String(10), - "name": String(255), - }, - method="multi", - index=False, - ) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + # TODO(bkyryliuk): use TIMESTAMP type for presto + "ds": DateTime if database.backend != "presto" else String(255), + "gender": String(16), + "state": String(10), + "name": String(255), + }, + method="multi", + index=False, + ) print("Done loading table!") print("-" * 80) diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 302b55180ea8..4331033ca836 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ -39,38 +39,39 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N """Loading data for map with country map""" tbl_name = "birth_france_by_region" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - 
schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("birth_france_data_for_country_map.csv") - data = pd.read_csv(url, encoding="utf-8") - data["dttm"] = datetime.datetime.now().date() - data.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "DEPT_ID": String(10), - "2003": BigInteger, - "2004": BigInteger, - "2005": BigInteger, - "2006": BigInteger, - "2007": BigInteger, - "2008": BigInteger, - "2009": BigInteger, - "2010": BigInteger, - "2011": BigInteger, - "2012": BigInteger, - "2013": BigInteger, - "2014": BigInteger, - "dttm": Date(), - }, - index=False, - ) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + url = get_example_url("birth_france_data_for_country_map.csv") + data = pd.read_csv(url, encoding="utf-8") + data["dttm"] = datetime.datetime.now().date() + data.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + "DEPT_ID": String(10), + "2003": BigInteger, + "2004": BigInteger, + "2005": BigInteger, + "2006": BigInteger, + "2007": BigInteger, + "2008": BigInteger, + "2009": BigInteger, + "2010": BigInteger, + "2011": BigInteger, + "2012": BigInteger, + "2013": BigInteger, + "2014": BigInteger, + "dttm": Date(), + }, + index=False, + ) print("Done loading table!") print("-" * 80) diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 72b22525f276..6688e5d08844 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -41,24 +41,25 @@ def load_energy( """Loads an energy related dataset to use with sankey and graphs""" tbl_name = "energy_usage" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - 
schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("energy.json.gz") - pdf = pd.read_json(url, compression="gzip") - pdf = pdf.head(100) if sample else pdf - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={"source": String(255), "target": String(255), "value": Float()}, - index=False, - method="multi", - ) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + url = get_example_url("energy.json.gz") + pdf = pd.read_json(url, compression="gzip") + pdf = pdf.head(100) if sample else pdf + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={"source": String(255), "target": String(255), "value": Float()}, + index=False, + method="multi", + ) print("Creating table [wb_health_population] reference") table = get_table_connector_registry() diff --git a/superset/examples/flights.py b/superset/examples/flights.py index 1389c65c9a90..7c8f9802988b 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -27,35 +27,37 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: """Loading random time series data from a zip file in the repo""" tbl_name = "flights" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - flight_data_url = get_example_url("flight_data.csv.gz") - pdf = pd.read_csv(flight_data_url, 
encoding="latin-1", compression="gzip") + if not only_metadata and (not table_exists or force): + flight_data_url = get_example_url("flight_data.csv.gz") + pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip") - # Loading airports info to join and get lat/long - airports_url = get_example_url("airports.csv.gz") - airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip") - airports = airports.set_index("IATA_CODE") + # Loading airports info to join and get lat/long + airports_url = get_example_url("airports.csv.gz") + airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip") + airports = airports.set_index("IATA_CODE") - pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression - "ds" - ] = (pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)) - pdf.ds = pd.to_datetime(pdf.ds) - pdf.drop(columns=["DAY", "MONTH", "YEAR"]) - pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG") - pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST") - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={"ds": DateTime}, - index=False, - ) + pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression + "ds" + ] = ( + pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str) + ) + pdf.ds = pd.to_datetime(pdf.ds) + pdf.drop(columns=["DAY", "MONTH", "YEAR"]) + pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG") + pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST") + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={"ds": DateTime}, + index=False, + ) table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 76f51a615951..88b45548f48d 100644 --- a/superset/examples/long_lat.py +++ 
b/superset/examples/long_lat.py @@ -39,49 +39,51 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None """Loading lat/long data from a csv file in the repo""" tbl_name = "long_lat" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("san_francisco.csv.gz") - pdf = pd.read_csv(url, encoding="utf-8", compression="gzip") - start = datetime.datetime.now().replace( - hour=0, minute=0, second=0, microsecond=0 - ) - pdf["datetime"] = [ - start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1)) - for i in range(len(pdf)) - ] - pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))] - pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))] - pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1) - pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",") - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "longitude": Float(), - "latitude": Float(), - "number": Float(), - "street": String(100), - "unit": String(10), - "city": String(50), - "district": String(50), - "region": String(50), - "postcode": Float(), - "id": String(100), - "datetime": DateTime(), - "occupancy": Float(), - "radius_miles": Float(), - "geohash": String(12), - "delimited": String(60), - }, - index=False, - ) + if not only_metadata and (not table_exists or force): + url = get_example_url("san_francisco.csv.gz") + pdf = pd.read_csv(url, encoding="utf-8", compression="gzip") + start = datetime.datetime.now().replace( + hour=0, minute=0, second=0, microsecond=0 + ) + pdf["datetime"] = [ 
+ start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1)) + for i in range(len(pdf)) + ] + pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))] + pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))] + pdf["geohash"] = pdf[["LAT", "LON"]].apply( + lambda x: geohash.encode(*x), axis=1 + ) + pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",") + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + "longitude": Float(), + "latitude": Float(), + "number": Float(), + "street": String(100), + "unit": String(10), + "city": String(50), + "district": String(50), + "region": String(50), + "postcode": Float(), + "id": String(100), + "datetime": DateTime(), + "occupancy": Float(), + "radius_miles": Float(), + "geohash": String(12), + "delimited": String(60), + }, + index=False, + ) print("Done loading table!") print("-" * 80) diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 62e16d2cb088..b030bcdb0f23 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -39,41 +39,41 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals """Loading time series data from a zip file in the repo""" tbl_name = "multiformat_time_series" database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("multiformat_time_series.json.gz") - pdf = pd.read_json(url, compression="gzip") - # TODO(bkyryliuk): move load examples data into the pytest fixture - if database.backend == "presto": - pdf.ds = pd.to_datetime(pdf.ds, 
unit="s") - pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d") - pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") - pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S") - else: - pdf.ds = pd.to_datetime(pdf.ds, unit="s") - pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") + if not only_metadata and (not table_exists or force): + url = get_example_url("multiformat_time_series.json.gz") + pdf = pd.read_json(url, compression="gzip") + # TODO(bkyryliuk): move load examples data into the pytest fixture + if database.backend == "presto": + pdf.ds = pd.to_datetime(pdf.ds, unit="s") + pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d") + pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") + pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S") + else: + pdf.ds = pd.to_datetime(pdf.ds, unit="s") + pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "ds": String(255) if database.backend == "presto" else Date, - "ds2": String(255) if database.backend == "presto" else DateTime, - "epoch_s": BigInteger, - "epoch_ms": BigInteger, - "string0": String(100), - "string1": String(100), - "string2": String(100), - "string3": String(100), - }, - index=False, - ) + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + "ds": String(255) if database.backend == "presto" else Date, + "ds2": String(255) if database.backend == "presto" else DateTime, + "epoch_s": BigInteger, + "epoch_ms": BigInteger, + "string0": String(100), + "string1": String(100), + "string2": String(100), + "string3": String(100), + }, + index=False, + ) print("Done loading table!") print("-" * 80) diff --git a/superset/examples/paris.py b/superset/examples/paris.py index c32300702852..a54a3706b13c 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -28,29 +28,29 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "paris_iris_mapping" database = 
database_utils.get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("paris_iris.json.gz") - df = pd.read_json(url, compression="gzip") - df["features"] = df.features.map(json.dumps) + if not only_metadata and (not table_exists or force): + url = get_example_url("paris_iris.json.gz") + df = pd.read_json(url, compression="gzip") + df["features"] = df.features.map(json.dumps) - df.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "color": String(255), - "name": String(255), - "features": Text, - "type": Text, - }, - index=False, - ) + df.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + "color": String(255), + "name": String(255), + "features": Text, + "type": Text, + }, + index=False, + ) print("Creating table {} reference".format(tbl_name)) table = get_table_connector_registry() diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 4a2628df7a07..9a296ec2c471 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -37,28 +37,28 @@ def load_random_time_series_data( """Loading random time series data from a zip file in the repo""" tbl_name = "random_time_series" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not 
table_exists or force): - url = get_example_url("random_time_series.json.gz") - pdf = pd.read_json(url, compression="gzip") - if database.backend == "presto": - pdf.ds = pd.to_datetime(pdf.ds, unit="s") - pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S") - else: - pdf.ds = pd.to_datetime(pdf.ds, unit="s") + if not only_metadata and (not table_exists or force): + url = get_example_url("random_time_series.json.gz") + pdf = pd.read_json(url, compression="gzip") + if database.backend == "presto": + pdf.ds = pd.to_datetime(pdf.ds, unit="s") + pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S") + else: + pdf.ds = pd.to_datetime(pdf.ds, unit="s") - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={"ds": DateTime if database.backend != "presto" else String(255)}, - index=False, - ) + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={"ds": DateTime if database.backend != "presto" else String(255)}, + index=False, + ) print("Done loading table!") print("-" * 80) diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index 71ba34401af9..6011b82b0965 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -30,29 +30,29 @@ def load_sf_population_polygons( ) -> None: tbl_name = "sf_population_polygons" database = database_utils.get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - if not only_metadata and (not table_exists or force): - url = get_example_url("sf_population.json.gz") - df = pd.read_json(url, compression="gzip") - df["contour"] = df.contour.map(json.dumps) + if not only_metadata and (not table_exists or 
force): + url = get_example_url("sf_population.json.gz") + df = pd.read_json(url, compression="gzip") + df["contour"] = df.contour.map(json.dumps) - df.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=500, - dtype={ - "zipcode": BigInteger, - "population": BigInteger, - "contour": Text, - "area": Float, - }, - index=False, - ) + df.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=500, + dtype={ + "zipcode": BigInteger, + "population": BigInteger, + "contour": Text, + "area": Float, + }, + index=False, + ) print("Creating table {} reference".format(tbl_name)) table = get_table_connector_registry() diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index aa4f404ccb0f..551741bf7d17 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -453,11 +453,11 @@ def load_supported_charts_dashboard() -> None: """Loading a dashboard featuring supported charts""" database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name - tbl_name = "birth_names" - table_exists = database.has_table_by_name(tbl_name, schema=schema) + tbl_name = "birth_names" + table_exists = database.has_table_by_name(tbl_name, schema=schema) if table_exists: table = get_table_connector_registry() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 4a18f806eae5..b65ad68d1af6 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -51,37 +51,38 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s """Loads the world bank health dataset, slices and a dashboard""" tbl_name = "wb_health_population" database = superset.utils.database.get_example_database() - engine = 
database.get_sqla_engine() - schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + with database.get_sqla_engine_with_context() as engine: - if not only_metadata and (not table_exists or force): - url = get_example_url("countries.json.gz") - pdf = pd.read_json(url, compression="gzip") - pdf.columns = [col.replace(".", "_") for col in pdf.columns] - if database.backend == "presto": - pdf.year = pd.to_datetime(pdf.year) - pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S") - else: - pdf.year = pd.to_datetime(pdf.year) - pdf = pdf.head(100) if sample else pdf + schema = inspect(engine).default_schema_name + table_exists = database.has_table_by_name(tbl_name) - pdf.to_sql( - tbl_name, - engine, - schema=schema, - if_exists="replace", - chunksize=50, - dtype={ - # TODO(bkyryliuk): use TIMESTAMP type for presto - "year": DateTime if database.backend != "presto" else String(255), - "country_code": String(3), - "country_name": String(255), - "region": String(255), - }, - method="multi", - index=False, - ) + if not only_metadata and (not table_exists or force): + url = get_example_url("countries.json.gz") + pdf = pd.read_json(url, compression="gzip") + pdf.columns = [col.replace(".", "_") for col in pdf.columns] + if database.backend == "presto": + pdf.year = pd.to_datetime(pdf.year) + pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S") + else: + pdf.year = pd.to_datetime(pdf.year) + pdf = pdf.head(100) if sample else pdf + + pdf.to_sql( + tbl_name, + engine, + schema=schema, + if_exists="replace", + chunksize=50, + dtype={ + # TODO(bkyryliuk): use TIMESTAMP type for presto + "year": DateTime if database.backend != "presto" else String(255), + "country_code": String(3), + "country_name": String(255), + "region": String(255), + }, + method="multi", + index=False, + ) print("Creating table [wb_health_population] reference") table = get_table_connector_registry() From 95d079e1963c6de9b20e67195513fbfb266e8c64 Mon Sep 17 
00:00:00 2001 From: hughhhh Date: Wed, 26 Oct 2022 19:52:23 -0400 Subject: [PATCH 12/75] change remaining bits --- superset/models/dashboard.py | 5 ++-- superset/models/filter_set.py | 5 ++-- superset/models/helpers.py | 12 +++++----- superset/sql_validators/presto_db.py | 24 ++++++++++--------- superset/utils/core.py | 4 ++-- superset/utils/mock_data.py | 36 ++++++++++++++-------------- superset/views/core.py | 8 +++---- 7 files changed, 49 insertions(+), 45 deletions(-) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 57567e61641c..a98d76e58162 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -224,8 +224,9 @@ def charts(self) -> List[str]: @property def sqla_metadata(self) -> None: # pylint: disable=no-member - meta = MetaData(bind=self.get_sqla_engine()) - meta.reflect() + with self.get_sqla_engine_with_context() as engine: + meta = MetaData(bind=engine) + meta.reflect() @property def status(self) -> utils.DashboardStatus: diff --git a/superset/models/filter_set.py b/superset/models/filter_set.py index 2d3b218793dc..4bbef264900d 100644 --- a/superset/models/filter_set.py +++ b/superset/models/filter_set.py @@ -55,8 +55,9 @@ def url(self) -> str: @property def sqla_metadata(self) -> None: # pylint: disable=no-member - meta = MetaData(bind=self.get_sqla_engine()) - meta.reflect() + with self.get_sqla_engine_with_context() as engine: + meta = MetaData(bind=engine) + meta.reflect() @property def changed_by_name(self) -> str: diff --git a/superset/models/helpers.py b/superset/models/helpers.py index da526b559c7f..28796383d83d 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1281,13 +1281,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: if limit: qry = qry.limit(limit) - engine = self.database.get_sqla_engine() # type: ignore - sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) - sql = self._apply_cte(sql, cte) - sql = 
self.mutate_query_from_config(sql) + with self.database.get_sqla_engine_with_context() as engine: # typing: ignore + sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) + sql = self._apply_cte(sql, cte) + sql = self.mutate_query_from_config(sql) - df = pd.read_sql_query(sql=sql, con=engine) - return df[column_name].to_list() + df = pd.read_sql_query(sql=sql, con=engine) + return df[column_name].to_list() def get_timestamp_expression( self, diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 70b324c90073..37375e484dec 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -162,16 +162,18 @@ def validate( statements = parsed_query.get_statements() logger.info("Validating %i statement(s)", len(statements)) - engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) - # Sharing a single connection and cursor across the - # execution of all statements (if many) - annotations: List[SQLValidationAnnotation] = [] - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - for statement in parsed_query.get_statements(): - annotation = cls.validate_statement(statement, database, cursor) - if annotation: - annotations.append(annotation) - logger.debug("Validation found %i error(s)", len(annotations)) + with database.get_sqla_engine_with_context( + schema, source=QuerySource.SQL_LAB + ) as engine: + # Sharing a single connection and cursor across the + # execution of all statements (if many) + annotations: List[SQLValidationAnnotation] = [] + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + for statement in parsed_query.get_statements(): + annotation = cls.validate_statement(statement, database, cursor) + if annotation: + annotations.append(annotation) + logger.debug("Validation found %i error(s)", len(annotations)) return annotations diff --git a/superset/utils/core.py b/superset/utils/core.py index a893696e024f..1cac98df2e49 
100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1258,8 +1258,8 @@ def get_example_default_schema() -> Optional[str]: Return the default schema of the examples database, if any. """ database = get_example_database() - engine = database.get_sqla_engine() - return inspect(engine).default_schema_name + with database.get_sqla_engine_with_context() as engine: + return inspect(engine).default_schema_name def backend() -> str: diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 904f7ee42e88..4b156cc10c10 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -187,29 +187,29 @@ def add_data( database = get_example_database() table_exists = database.has_table_by_name(table_name) - engine = database.get_sqla_engine() - if columns is None: - if not table_exists: - raise Exception( - f"The table {table_name} does not exist. To create it you need to " - "pass a list of column names and types." - ) + with database.get_sqla_engine_with_context() as engine: + if columns is None: + if not table_exists: + raise Exception( + f"The table {table_name} does not exist. To create it you need to " + "pass a list of column names and types." 
+ ) - inspector = inspect(engine) - columns = inspector.get_columns(table_name) + inspector = inspect(engine) + columns = inspector.get_columns(table_name) - # create table if needed - column_objects = get_column_objects(columns) - metadata = MetaData() - table = Table(table_name, metadata, *column_objects) - metadata.create_all(engine) + # create table if needed + column_objects = get_column_objects(columns) + metadata = MetaData() + table = Table(table_name, metadata, *column_objects) + metadata.create_all(engine) - if not append: - engine.execute(table.delete()) + if not append: + engine.execute(table.delete()) - data = generate_data(columns, num_rows) - engine.execute(table.insert(), data) + data = generate_data(columns, num_rows) + engine.execute(table.insert(), data) def get_column_objects(columns: List[ColumnInfo]) -> List[Column]: diff --git a/superset/views/core.py b/superset/views/core.py index 60ce1edfd2dc..bece5c1d03e0 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1378,11 +1378,11 @@ def testconn(self) -> FlaskResponse: ) database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - engine = database.get_sqla_engine() - with closing(engine.raw_connection()) as conn: - if engine.dialect.do_ping(conn): - return json_success('"OK"') + with database.get_sqla_engine_with_context() as engine: + with closing(engine.raw_connection()) as conn: + if engine.dialect.do_ping(conn): + return json_success('"OK"') raise DBAPIError(None, None, None) except CertificateException as ex: From d5926e3e157e83527decde3351ca9fde023e70a2 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 27 Oct 2022 15:03:13 -0400 Subject: [PATCH 13/75] add id --- ...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py 
b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 5bf50f870d1d..deb2a6d44c21 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -50,6 +50,7 @@ def upgrade(): # ImportExportMixin sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), # SSHTunnelCredentials + sa.Column("id", sa.Integer(), primary_key=True), sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), nullable=True), sa.Column("server_address", EncryptedType(sa.String, app_config["SECRET_KEY"])), sa.Column("server_port", EncryptedType(sa.String, app_config["SECRET_KEY"])), @@ -72,7 +73,6 @@ def upgrade(): EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True, ), - sa.PrimaryKeyConstraint("id"), ) From f7a6a411f2d4811b3a9841f4f15b49bd61f90fa1 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 27 Oct 2022 21:49:48 -0400 Subject: [PATCH 14/75] use factory instead --- ...c8595_create_ssh_tunnel_credentials_tbl.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index deb2a6d44c21..b4c57470f71c 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -30,9 +30,10 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy_utils import EncryptedType, UUIDType +from sqlalchemy_utils import UUIDType from superset import app +from superset.extensions import encrypted_field_factory app_config = app.config @@ -51,26 +52,20 @@ def upgrade(): sa.Column("uuid", 
UUIDType(binary=True), primary_key=False, default=uuid4), # SSHTunnelCredentials sa.Column("id", sa.Integer(), primary_key=True), - sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), nullable=True), - sa.Column("server_address", EncryptedType(sa.String, app_config["SECRET_KEY"])), - sa.Column("server_port", EncryptedType(sa.String, app_config["SECRET_KEY"])), + sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id")), + sa.Column("server_address", encrypted_field_factory.create(sa.String(1024))), + sa.Column("username", encrypted_field_factory.create(sa.String(1024))), sa.Column( - "username", - EncryptedType(sa.String, app_config["SECRET_KEY"]), - ), - sa.Column( - "password", - EncryptedType(sa.String, app_config["SECRET_KEY"]), - nullable=True, + "password", encrypted_field_factory.create(sa.String(1024)), nullable=True ), sa.Column( "private_key", - EncryptedType(sa.String, app_config["SECRET_KEY"]), + encrypted_field_factory.create(sa.String(1024)), nullable=True, ), sa.Column( "private_key_password", - EncryptedType(sa.String, app_config["SECRET_KEY"]), + encrypted_field_factory.create(sa.String(1024)), nullable=True, ), ) From 4146d5a64057e9c6ac366a59a72fecab61f247a1 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 31 Oct 2022 15:11:43 -0400 Subject: [PATCH 15/75] setup return value for contextmanager --- tests/integration_tests/sqllab_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index bee9b08114a4..d3a919eaa961 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -733,7 +733,7 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query mock_query = mock.MagicMock() mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() - mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = ( + 
mock_query.database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value = ( mock_cursor ) mock_query.database.db_engine_spec.run_multiple_statements_as_one = False From f8b877de6eb2e0f9b3f2c8c74fe300afab8430e2 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 31 Oct 2022 15:50:33 -0400 Subject: [PATCH 16/75] add sshtunnel pip --- requirements/base.txt | 30 ++++++++++++++++++++---------- requirements/development.txt | 4 +++- requirements/docker.txt | 2 +- requirements/local.txt | 2 +- setup.py | 1 + 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index 905f4e1edaa3..6a5b9f7a2e4c 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -14,12 +14,13 @@ amqp==5.1.0 apispec[yaml]==3.3.2 # via flask-appbuilder attrs==21.2.0 - # via - # jsonschema + # via jsonschema babel==2.9.1 # via flask-babel backoff==1.11.1 # via apache-superset +bcrypt==4.0.1 + # via paramiko billiard==3.6.4.0 # via celery bleach==3.3.1 @@ -31,7 +32,9 @@ cachelib==0.4.1 celery==5.2.2 # via apache-superset cffi==1.14.6 - # via cryptography + # via + # cryptography + # pynacl click==8.0.4 # via # apache-superset @@ -58,7 +61,9 @@ cron-descriptor==1.2.24 croniter==1.0.15 # via apache-superset cryptography==3.4.7 - # via apache-superset + # via + # apache-superset + # paramiko deprecation==2.1.0 # via apache-superset dnspython==2.1.0 @@ -115,13 +120,14 @@ gunicorn==20.1.0 # via apache-superset hashids==1.3.1 # via apache-superset +hijri-converter==2.2.4 + # via holidays holidays==0.16.0 # via apache-superset humanize==3.11.0 # via apache-superset idna==3.2 - # via - # email-validator + # via email-validator isodate==0.6.0 # via apache-superset itsdangerous==2.1.1 @@ -169,6 +175,8 @@ packaging==21.3 # deprecation pandas==1.4.4 # via apache-superset +paramiko==2.11.0 + # via sshtunnel parsedatetime==2.6 # via apache-superset pgsanity==0.2.9 @@ -190,6 +198,8 @@ pyjwt==2.4.0 # 
flask-jwt-extended pymeeus==0.5.11 # via convertdate +pynacl==1.5.0 + # via paramiko pyparsing==3.0.6 # via # apache-superset @@ -214,7 +224,6 @@ pytz==2021.3 # via # babel # celery - # convertdate # flask-babel # pandas pyyaml==5.4.1 @@ -232,16 +241,15 @@ six==1.16.0 # bleach # click-repl # flask-talisman - # holidays # isodate # jsonschema + # paramiko # polyline # prison # pyrsistent # python-dateutil - # sqlalchemy-utils # wtforms-json -slack_sdk==3.18.3 +slack-sdk==3.18.3 # via apache-superset sqlalchemy==1.4.36 # via @@ -257,6 +265,8 @@ sqlalchemy-utils==0.38.3 # flask-appbuilder sqlparse==0.4.3 # via apache-superset +sshtunnel==0.4.0 + # via apache-superset tabulate==0.8.9 # via apache-superset typing-extensions==3.10.0.0 diff --git a/requirements/development.txt b/requirements/development.txt index 5fc11b914cbf..a82d06c14ee3 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt --e file:. +-e file:. # via # -r requirements/base.in # -r requirements/development.in @@ -28,6 +28,8 @@ certifi==2021.10.8 # via requests chardet==4.0.0 # via tabulator +charset-normalizer==2.0.12 + # via requests decorator==5.1.1 # via ipython et-xmlfile==1.1.0 diff --git a/requirements/docker.txt b/requirements/docker.txt index 0c2d36159e4d..d6e8e661d4c6 100644 --- a/requirements/docker.txt +++ b/requirements/docker.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt --e file:. +-e file:. # via # -r requirements/base.in # -r requirements/docker.in diff --git a/requirements/local.txt b/requirements/local.txt index c4bd3cd599b3..10280f670daf 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r development.txt --e file:.
+-e file:. # via # -r requirements/base.in # -r requirements/development.in diff --git a/setup.py b/setup.py index 405346e4566b..934530c7d0f0 100644 --- a/setup.py +++ b/setup.py @@ -113,6 +113,7 @@ def get_git_sha() -> str: "PyJWT>=2.4.0, <3.0", "redis", "selenium>=3.141.0", + "sshtunnel>=0.4.0", "simplejson>=3.15.0", "slack_sdk>=3.1.1, <4", "sqlalchemy>=1.4, <2", From 54fc147fdf253d743e4ccd20a2ea295e11c7a69a Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 1 Nov 2022 12:15:44 -0400 Subject: [PATCH 17/75] updates test --- tests/integration_tests/sqllab_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index d3a919eaa961..d196dab5adf3 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -786,7 +786,7 @@ def test_execute_sql_statements_no_results_backend( mock_query = mock.MagicMock() mock_query.database.allow_run_async = True mock_cursor = mock.MagicMock() - mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = ( + mock_query.database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value = ( mock_cursor ) mock_query.database.db_engine_spec.run_multiple_statements_as_one = False @@ -836,7 +836,7 @@ def test_execute_sql_statements_ctas( mock_query = mock.MagicMock() mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() - mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = ( + mock_query.database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value = ( mock_cursor ) mock_query.database.db_engine_spec.run_multiple_statements_as_one = False From fdc6ca39496314bb30a449c4b1e9c8b6195dea81 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 2 Nov 2022 10:28:11 -0400
Subject: [PATCH 18/75] fix linting --- superset/examples/bart_lines.py | 1 - superset/models/helpers.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index eb11be2eed1b..5d167b02d062 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -30,7 +30,6 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "bart_lines" database = get_example_database() with database.get_sqla_engine_with_context() as engine: - engine = database.get_sqla_engine() schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 55de3f4be2fe..5da3d209dd24 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1281,7 +1281,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: if limit: qry = qry.limit(limit) - with self.database.get_sqla_engine_with_context() as engine: # type: ignore + with self.database.get_sqla_engine_with_context() as engine: # type: ignore sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) sql = self._apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) From 66c0801e494f16ba44d492f3648c1e1a57c8e35a Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 3 Nov 2022 12:39:33 -0400 Subject: [PATCH 19/75] renaming function --- superset/examples/birth_names.py | 6 +- superset/models/core.py | 24 +++---- tests/integration_tests/conftest.py | 63 ++++++++++--------- .../databases/commands_tests.py | 8 +-- .../db_engine_specs/hive_tests.py | 8 ++- .../db_engine_specs/presto_tests.py | 18 +++--- .../fixtures/unicode_dashboard.py | 4 +- .../fixtures/world_bank_dashboard.py | 4 +- tests/integration_tests/model_tests.py | 28 +++------ .../integration_tests/sql_validator_tests.py | 4 +- 10 files changed, 80 insertions(+), 87 deletions(-) diff --git 
a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 3fc86bc849cc..406a70b2cc4d 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -81,7 +81,7 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None: pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, schema=schema, if_exists="replace", chunksize=500, @@ -104,8 +104,8 @@ def load_birth_names( ) -> None: """Loading birth name dataset from a zip file in the repo""" database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name + with database.get_sqla_engine_with_context() as engine: + schema = inspect(engine).default_schema_name tbl_name = "birth_names" table_exists = database.has_table_by_name(tbl_name, schema=schema) diff --git a/superset/models/core.py b/superset/models/core.py index d0a32a1864c2..712d78a506c6 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -370,11 +370,11 @@ def get_sqla_engine_with_context( source: Optional[utils.QuerySource] = None, ) -> Engine: try: - yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) except Exception as ex: - raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + raise ex - def get_sqla_engine( + def _get_sqla_engine( self, schema: Optional[str] = None, nullpool: bool = True, @@ -392,7 +392,7 @@ def get_sqla_engine( ) masked_url = self.get_password_masked_url(sqlalchemy_url) - logger.debug("Database.get_sqla_engine(). Masked URL: %s", str(masked_url)) + logger.debug("Database._get_sqla_engine(). 
Masked URL: %s", str(masked_url)) params = extra.get("engine_params", {}) if nullpool: @@ -442,7 +442,7 @@ def get_df( # pylint: disable=too-many-locals mutator: Optional[Callable[[pd.DataFrame], None]] = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - engine = self.get_sqla_engine(schema) + engine = self._get_sqla_engine(schema) def needs_conversion(df_series: pd.Series) -> bool: return ( @@ -487,7 +487,7 @@ def _log_query(sql: str) -> None: return df def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: - engine = self.get_sqla_engine(schema=schema) + engine = self._get_sqla_engine(schema=schema) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -508,7 +508,7 @@ def select_star( # pylint: disable=too-many-arguments cols: Optional[List[Dict[str, Any]]] = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" - eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) + eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) return self.db_engine_spec.select_star( self, table_name, @@ -533,7 +533,7 @@ def safe_sqlalchemy_uri(self) -> str: @property def inspector(self) -> Inspector: - engine = self.get_sqla_engine() + engine = self._get_sqla_engine() return sqla.inspect(engine) @cache_util.memoized_func( @@ -673,7 +673,7 @@ def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: meta, schema=schema or None, autoload=True, - autoload_with=self.get_sqla_engine(), + autoload_with=self._get_sqla_engine(), ) def get_table_comment( @@ -759,11 +759,11 @@ def get_perm(self) -> str: return self.perm # type: ignore def has_table(self, table: Table) -> bool: - engine = self.get_sqla_engine() + engine = self._get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: - engine = self.get_sqla_engine() + engine = 
self._get_sqla_engine() return engine.has_table(table_name, schema) @classmethod @@ -782,7 +782,7 @@ def _has_view( return view_name in view_names def has_view(self, view_name: str, schema: Optional[str] = None) -> bool: - engine = self.get_sqla_engine() + engine = self._get_sqla_engine() return engine.run_callable(self._has_view, engine.dialect, view_name, schema) def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool: diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index efbc6bf7f07d..8908c3e22782 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -171,7 +171,7 @@ def __call__(self) -> Database: return self._db def _load_lazy_data_to_decouple_from_session(self) -> None: - self._db.get_sqla_engine() # type: ignore + self._db._get_sqla_engine() # type: ignore self._db.backend # type: ignore def remove(self) -> None: @@ -336,37 +336,38 @@ def physical_dataset(): from superset.connectors.sqla.utils import get_identifier_quoter example_database = get_example_database() - engine = example_database.get_sqla_engine() - quoter = get_identifier_quoter(engine.name) - # sqlite can only execute one statement at a time - engine.execute( - f""" - CREATE TABLE IF NOT EXISTS physical_dataset( - col1 INTEGER, - col2 VARCHAR(255), - col3 DECIMAL(4,2), - col4 VARCHAR(255), - col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01', - col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01', - {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01' - ); - """ - ) - engine.execute( + + with example_database.get_sqla_engine_with_context() as engine: + quoter = get_identifier_quoter(engine.name) + # sqlite can only execute one statement at a time + engine.execute( + f""" + CREATE TABLE IF NOT EXISTS physical_dataset( + col1 INTEGER, + col2 VARCHAR(255), + col3 DECIMAL(4,2), + col4 VARCHAR(255), + col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01', + col6 TIMESTAMP DEFAULT '1970-01-01 
00:00:01', + {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01' + ); + """ + ) + engine.execute( + """ + INSERT INTO physical_dataset values + (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'), + (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'), + (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'), + (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'), + (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'), + (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'), + (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'), + (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'), + (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'), + (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00'); """ - INSERT INTO physical_dataset values - (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'), - (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'), - (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'), - (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'), - (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'), - (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'), - (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'), - (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'), - (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'), - (9, 'j', 1.9, NULL, '2000-01-10 
00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00'); - """ - ) + ) dataset = SqlaTable( table_name="physical_dataset", diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 4426fa756ff5..64c9b260c4ab 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -641,7 +641,7 @@ def test_import_v1_rollback(self, mock_import_dataset): class TestTestConnectionDatabaseCommand(SupersetTestCase): - @mock.patch("superset.databases.dao.Database.get_sqla_engine") + @mock.patch("superset.databases.dao.Database._get_sqla_engine") @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) @@ -664,7 +664,7 @@ def test_connection_db_exception( ) mock_event_logger.assert_called() - @mock.patch("superset.databases.dao.Database.get_sqla_engine") + @mock.patch("superset.databases.dao.Database._get_sqla_engine") @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) @@ -713,7 +713,7 @@ def test_connection_do_ping_timeout( == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT ) - @mock.patch("superset.databases.dao.Database.get_sqla_engine") + @mock.patch("superset.databases.dao.Database._get_sqla_engine") @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) @@ -738,7 +738,7 @@ def test_connection_superset_security_connection( mock_event_logger.assert_called() - @mock.patch("superset.databases.dao.Database.get_sqla_engine") + @mock.patch("superset.databases.dao.Database._get_sqla_engine") @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index ad80f8397ffe..7fc96e23f7ce 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ 
b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -208,7 +208,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g): mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False mock_execute = mock.MagicMock(return_value=True) - mock_database.get_sqla_engine.return_value.execute = mock_execute + mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = ( + mock_execute + ) table_name = "foobar" with app.app_context(): @@ -233,7 +235,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False mock_execute = mock.MagicMock(return_value=True) - mock_database.get_sqla_engine.return_value.execute = mock_execute + mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = ( + mock_execute + ) table_name = "foobar" schema = "schema" diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 2d6cf7b8622c..f439e09288be 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -46,10 +46,10 @@ def test_get_view_names(self, mock_is_feature_enabled): mock_execute = mock.MagicMock() mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]]) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( + database.get_sqla_engine_with_context.return_value.__enter__.raw_connection.return_value.cursor.return_value.fetchall = ( mock_fetchall ) result = PrestoEngineSpec.get_view_names(database, 
mock.Mock(), None) @@ -64,10 +64,10 @@ def test_get_view_names_with_schema(self, mock_is_feature_enabled): mock_execute = mock.MagicMock() mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]]) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( mock_fetchall ) schema = "schema" @@ -855,13 +855,13 @@ def test_get_create_view(self): mock_execute = mock.MagicMock() mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]]) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( mock_fetchall ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = ( False ) schema = "schema" @@ -873,7 +873,7 @@ def test_get_create_view(self): def test_get_create_view_exception(self): mock_execute = mock.MagicMock(side_effect=Exception()) database = mock.MagicMock() - 
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) schema = "schema" @@ -886,7 +886,7 @@ def test_get_create_view_database_error(self): mock_execute = mock.MagicMock(side_effect=DatabaseError()) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) schema = "schema" diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index 9368df7614a9..78178bcde755 100644 --- a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -51,8 +51,8 @@ def load_unicode_data(): yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - engine.execute("DROP TABLE IF EXISTS unicode_test") + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS unicode_test") @pytest.fixture() diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index e29962a8c978..561bbe10b270 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -64,8 +64,8 @@ def load_world_bank_data(): yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - engine.execute("DROP TABLE IF EXISTS wb_health_population") + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS wb_health_population") @pytest.fixture() diff --git a/tests/integration_tests/model_tests.py 
b/tests/integration_tests/model_tests.py index 3e13664b63e3..f187eadfbb27 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -164,7 +164,7 @@ def test_impersonate_user_presto(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri=uri, extra=extra ) model.impersonate_user = True - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "presto://gamma@localhost" @@ -177,7 +177,7 @@ def test_impersonate_user_presto(self, mocked_create_engine): } model.impersonate_user = False - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "presto://localhost" @@ -197,7 +197,7 @@ def test_impersonate_user_trino(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri="trino://localhost" ) model.impersonate_user = True - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "trino://localhost" @@ -209,7 +209,7 @@ def test_impersonate_user_trino(self, mocked_create_engine): ) model.impersonate_user = True - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert ( @@ -242,7 +242,7 @@ def test_impersonate_user_hive(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri=uri, extra=extra ) model.impersonate_user = True - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "hive://localhost" @@ -255,7 +255,7 @@ def test_impersonate_user_hive(self, mocked_create_engine): } model.impersonate_user = False - model.get_sqla_engine() + model._get_sqla_engine() call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "hive://localhost" @@ -380,21 +380,7 @@ def test_get_sqla_engine(self, mocked_create_engine): ) mocked_create_engine.side_effect 
= Exception() with self.assertRaises(SupersetException): - model.get_sqla_engine() - - # todo(hughhh): update this test - # @mock.patch("superset.models.core.create_engine") - # def test_get_sqla_engine_with_context(self, mocked_create_engine): - # model = Database( - # database_name="test_database", - # sqlalchemy_uri="mysql://root@localhost", - # ) - # model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock( - # return_value={Exception: SupersetException} - # ) - # mocked_create_engine.side_effect = Exception() - # with self.assertRaises(SupersetException): - # model.get_sqla_engine() + model._get_sqla_engine() class TestSqlaTableModel(SupersetTestCase): diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index ff4c74fa45fb..d2f6e7108d42 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase): def setUp(self): self.validator = PrestoDBSQLValidator self.database = MagicMock() - self.database_engine = self.database.get_sqla_engine.return_value + self.database_engine = ( + self.database.get_sqla_engine_with_context.return_value.__enter__.return_value + ) self.database_conn = self.database_engine.raw_connection.return_value self.database_cursor = self.database_conn.cursor.return_value self.database_cursor.poll.return_value = None From 1f9ec5e8ae431509c23ad5829a22c1962d407659 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Sat, 5 Nov 2022 14:56:40 -0400 Subject: [PATCH 20/75] fix test --- tests/integration_tests/db_engine_specs/presto_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index f439e09288be..2363b1c8741b 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -49,7 
+49,7 @@ def test_get_view_names(self, mock_is_feature_enabled): database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine_with_context.return_value.__enter__.raw_connection.return_value.cursor.return_value.fetchall = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( mock_fetchall ) result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None) From 41bd19b7522663f80e8cc970683849d7ee5b3d6b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 12:29:21 -0500 Subject: [PATCH 21/75] add schema to test_connection api --- superset/databases/schemas.py | 44 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 6b5d0746f657..250321e12b22 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -456,6 +456,26 @@ class Meta: # pylint: disable=too-few-public-methods external_url = fields.String(allow_none=True) +class DatabaseSSHTunnelCredentials(Schema): + id = fields.Integer() + database_id = fields.Integer() + + server_address = fields.String() + server_port = fields.Integer() + username = fields.String() + + # Basic Authentication + password = fields.String(required=False) + + # password protected private key authentication + private_key = fields.String(required=False) + private_key_password = fields.String(required=False) + + # remote binding port + bind_host = fields.String() + bind_port = fields.Integer() + + class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): rename_encrypted_extra = pre_load(rename_encrypted_extra) @@ -482,6 +502,10 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): validate=[Length(1, 1024), sqlalchemy_uri_validator], ) + ssh_tunnel_credentials = 
fields.Nested( + DatabaseSSHTunnelCredentials, allow_none=True + ) + class TableMetadataOptionsResponseSchema(Schema): deferrable = fields.Bool() @@ -700,26 +724,6 @@ class EncryptedDict(EncryptedField, fields.Dict): pass -class DatabaseSSHTunnelCredentials(Schema): - id = fields.Integer() - database_id = fields.Integer() - - server_address = fields.String() - server_port = fields.Integer() - username = fields.String() - - # Basic Authentication - password = fields.String(required=False) - - # password protected private key authentication - private_key = fields.String(required=False) - private_key_password = fields.String(required=False) - - # remote binding port - bind_host = fields.String() - bind_port = fields.Integer() - - def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore ret = {} if isinstance(field, EncryptedField): From 8811a99fd8897fd5eb04c0178e4868e6f75da4d5 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 14:01:30 -0500 Subject: [PATCH 22/75] fix get engine to return contextmanager --- superset/db_engine_specs/base.py | 30 +++++++++++------------ superset/db_engine_specs/bigquery.py | 8 +++++-- superset/db_engine_specs/gsheets.py | 11 +++++---- superset/db_engine_specs/hive.py | 5 ++-- superset/db_engine_specs/presto.py | 32 ++++++++++++------------- tests/integration_tests/celery_tests.py | 3 ++- 6 files changed, 47 insertions(+), 42 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 9dd7594dc720..0db3aef42dc4 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -18,7 +18,7 @@ import json import logging import re -from contextlib import closing +from contextlib import closing, contextmanager from datetime import datetime from typing import ( Any, @@ -472,10 +472,8 @@ def get_engine( schema: Optional[str] = None, source: Optional[utils.QuerySource] = None, ) -> Engine: - with database.get_sqla_engine_with_context( - 
schema=schema, source=source - ) as engine: - return engine + # this function now returns a context manager associated with the base class + return database.get_sqla_engine_with_context(schema=schema, source=source) @classmethod def get_timestamp_expr( @@ -897,17 +895,17 @@ def df_to_sql( :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method """ - engine = cls.get_engine(database) to_sql_kwargs["name"] = table.table if table.schema: # Only add schema when it is preset and non empty. to_sql_kwargs["schema"] = table.schema - if engine.dialect.supports_multivalues_insert: - to_sql_kwargs["method"] = "multi" + with cls.get_engine(database) as engine: + if engine.dialect.supports_multivalues_insert: + to_sql_kwargs["method"] = "multi" - df.to_sql(con=engine, **to_sql_kwargs) + df.to_sql(con=engine, **to_sql_kwargs) @classmethod def convert_dttm( # pylint: disable=unused-argument @@ -1264,13 +1262,15 @@ def estimate_query_cost( parsed_query = sql_parse.ParsedQuery(sql) statements = parsed_query.get_statements() - engine = cls.get_engine(database, schema=schema, source=source) costs = [] - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - for statement in statements: - processed_statement = cls.process_statement(statement, database) - costs.append(cls.estimate_statement_cost(processed_statement, cursor)) + with cls.get_engine(database, schema=schema, source=source) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + for statement in statements: + processed_statement = cls.process_statement(statement, database) + costs.append( + cls.estimate_statement_cost(processed_statement, cursor) + ) return costs @classmethod diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 59ed8e1bc9e1..46dd003b208a 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -340,8 +340,12 @@ def df_to_sql( if not table.schema: raise 
Exception("The table schema must be defined") - engine = cls.get_engine(database) - to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host} + to_gbq_kwargs = {} + with cls.get_engine(database) as engine: + to_gbq_kwargs = { + "destination_table": str(table), + "project_id": engine.url.host, + } # Add credentials if they are set on the SQLAlchemy dialect. creds = engine.dialect.credentials_info diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index fd1a2754d76b..83f1a8c8f5d1 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -17,6 +17,7 @@ import json import re from contextlib import closing +from msilib import schema from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING from apispec import APISpec @@ -109,11 +110,11 @@ def extra_table_metadata( table_name: str, schema_name: Optional[str], ) -> Dict[str, Any]: - engine = cls.get_engine(database, schema=schema_name) - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - cursor.execute(f'SELECT GET_METADATA("{table_name}")') - results = cursor.fetchone()[0] + with cls.get_engine(database, schema=schema_name) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + cursor.execute(f'SELECT GET_METADATA("{table_name}")') + results = cursor.fetchone()[0] try: metadata = json.loads(results) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index b37348e911ec..3c541c357ea5 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -185,8 +185,6 @@ def df_to_sql( :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method """ - engine = cls.get_engine(database) - if to_sql_kwargs["if_exists"] == "append": raise SupersetException("Append operation not currently supported") @@ -205,7 +203,8 @@ def df_to_sql( if table_exists: raise SupersetException("Table already exists") 
elif to_sql_kwargs["if_exists"] == "replace": - engine.execute(f"DROP TABLE IF EXISTS {str(table)}") + with cls.get_engine(database) as engine: + engine.execute(f"DROP TABLE IF EXISTS {str(table)}") def _get_hive_type(dtype: np.dtype) -> str: hive_type_by_dtype = { diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 22e4f7594ccf..f8f5e1bc18b7 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -425,11 +425,11 @@ def get_view_names( sql = "SELECT table_name FROM information_schema.views" params = {} - engine = cls.get_engine(database, schema=schema) - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - cursor.execute(sql, params) - results = cursor.fetchall() + with cls.get_engine(database, schema=schema) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + cursor.execute(sql, params) + results = cursor.fetchall() return [row[0] for row in results] @@ -951,17 +951,17 @@ def get_create_view( # pylint: disable=import-outside-toplevel from pyhive.exc import DatabaseError - engine = cls.get_engine(database, schema) - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - sql = f"SHOW CREATE VIEW {schema}.{table}" - try: - cls.execute(cursor, sql) - - except DatabaseError: # not a VIEW - return None - rows = cls.fetch_data(cursor, 1) - return rows[0][0] + with cls.get_engine(database, schema=schema) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + sql = f"SHOW CREATE VIEW {schema}.{table}" + try: + cls.execute(cursor, sql) + + except DatabaseError: # not a VIEW + return None + rows = cls.fetch_data(cursor, 1) + return rows[0][0] @classmethod def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index da6db727e711..4eaa1746b1e6 100644 --- a/tests/integration_tests/celery_tests.py 
+++ b/tests/integration_tests/celery_tests.py @@ -490,7 +490,8 @@ def my_task(): def delete_tmp_view_or_table(name: str, db_object_type: str): - db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}") + with db.get_sqla_engine_with_context() as engine: + engine.execute(f"DROP {db_object_type} IF EXISTS {name}") def wait_for_success(result): From 82d7532a8fb0de4e5e953d1ea2b533b9d4c24d9e Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 14:54:10 -0500 Subject: [PATCH 23/75] why --- superset/db_engine_specs/gsheets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 83f1a8c8f5d1..805a7ee400cf 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -17,7 +17,6 @@ import json import re from contextlib import closing -from msilib import schema from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING from apispec import APISpec From 1f829ac321f2a9fc22e795a62af4a787b6241970 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 15:06:32 -0500 Subject: [PATCH 24/75] yerp --- tests/integration_tests/celery_tests.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 4eaa1746b1e6..da6db727e711 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -490,8 +490,7 @@ def my_task(): def delete_tmp_view_or_table(name: str, db_object_type: str): - with db.get_sqla_engine_with_context() as engine: - engine.execute(f"DROP {db_object_type} IF EXISTS {name}") + db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}") def wait_for_success(result): From d53d116b4201c9a601f254c488054b457a95e28e Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 17:14:43 -0500 Subject: [PATCH 25/75] update typing --- tests/integration_tests/db_engine_specs/bigquery_tests.py | 
6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 1825f9587cc8..4e09077f5cce 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -227,8 +227,10 @@ def test_df_to_sql(self, mock_get_engine): return_value="account_info" ) - mock_get_engine.return_value.url.host = "google-host" - mock_get_engine.return_value.dialect.credentials_info = "secrets" + mock_get_engine.return_value.__enter__.return_value.url.host = "google-host" + mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = ( + "secrets" + ) BigQueryEngineSpec.df_to_sql( database=database, From 752161d70636143803a0cc7b93ab6e4d77777e6b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 7 Nov 2022 17:25:34 -0500 Subject: [PATCH 26/75] update comment --- superset/db_engine_specs/base.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 0db3aef42dc4..e992788e98ae 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -23,6 +23,7 @@ from typing import ( Any, Callable, + ContextManager, Dict, List, Match, @@ -471,8 +472,15 @@ def get_engine( database: "Database", schema: Optional[str] = None, source: Optional[utils.QuerySource] = None, - ) -> Engine: - # this function now returns a context manager associated with the base class + ) -> ContextManager[Engine]: + """ + Return an engine context manager. + + >>> with DBEngineSpec.get_engine(database, schema, source) as engine: + ... connection = engine.connect() + ... 
connection.execute(sql) + + """ return database.get_sqla_engine_with_context(schema=schema, source=source) @classmethod From 1a19a97707dc3e523446fb66276bf28b74e14079 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 8 Nov 2022 10:22:16 -0500 Subject: [PATCH 27/75] save --- superset/databases/commands/test_connection.py | 9 ++++++++- superset/models/core.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 115a629a1d5f..62923ec0c36f 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -89,6 +89,13 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) + + from superset.databases.models import SSHTunnelCredentials + + if self._properties.get('ssh_tunnel_credentials'): + ssh_tunnel_credentials = SSHTunnelCredentials( + **self._properties['ssh_tunnel_credentials'] + ) event_logger.log_with_context( action="test_connection_attempt", @@ -99,7 +106,7 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine_with_context(ssh_tunnel_credentials=ssh_tunnel_credentials) as engine: try: alive = func_timeout( int( diff --git a/superset/models/core.py b/superset/models/core.py index 712d78a506c6..08564db740ba 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -55,6 +55,7 @@ from superset import app, db_engine_specs from superset.constants import PASSWORD_MASK +from superset.databases.models import SSHTunnelCredentials from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, 
security_manager @@ -368,17 +369,27 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, + ssh_tunnel_credentials: Optional[SSHTunnelCredentials] = None ) -> Engine: + if ssh_tunnel_credentials: + # build with override + print('building with params') + else: + # do look up in table for using database_id + print('doing look up on table') + try: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source, ssh_tunnel_credentials=ssh_tunnel_credentials) except Exception as ex: raise ex + import sshtunnel def _get_sqla_engine( self, schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, + ssh_tunnel_server: Optional[sshtunnel.SSHTunnelForwarder] = None, ) -> Engine: extra = self.get_extra() sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted) @@ -422,6 +433,8 @@ def _get_sqla_engine( sqlalchemy_url, params, effective_username, security_manager, source ) + if ssh_tunnel_server: + # update sqlalchemy_url try: return create_engine(sqlalchemy_url, **params) except Exception as ex: From 58b9cce3de87da5186ac22cf362d8df41924c711 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 8 Nov 2022 11:14:08 -0500 Subject: [PATCH 28/75] save --- superset/databases/commands/test_connection.py | 12 +++++++----- superset/models/core.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 62923ec0c36f..015a510127ca 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -89,12 +89,12 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - + from superset.databases.models import 
SSHTunnelCredentials - - if self._properties.get('ssh_tunnel_credentials'): + + if self._properties.get("ssh_tunnel_credentials"): ssh_tunnel_credentials = SSHTunnelCredentials( - **self._properties['ssh_tunnel_credentials'] + **self._properties["ssh_tunnel_credentials"] ) event_logger.log_with_context( @@ -106,7 +106,9 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context(ssh_tunnel_credentials=ssh_tunnel_credentials) as engine: + with database.get_sqla_engine_with_context( + ssh_tunnel_credentials=ssh_tunnel_credentials + ) as engine: try: alive = func_timeout( int( diff --git a/superset/models/core.py b/superset/models/core.py index 08564db740ba..fad91be80503 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,7 +372,7 @@ def get_sqla_engine_with_context( ssh_tunnel_credentials: Optional[SSHTunnelCredentials] = None ) -> Engine: if ssh_tunnel_credentials: - # build with override + # build with override print('building with params') else: # do look up in table for using database_id From 31f3c1da8a4cf1f2d6307dd281046d2189a092ac Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 8 Nov 2022 12:30:00 -0500 Subject: [PATCH 29/75] fix pylint --- superset/db_engine_specs/druid.py | 2 +- superset/models/core.py | 4 ++-- superset/views/core.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 1484429bafbc..0fa47eeb5400 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -140,7 +140,7 @@ def get_columns( @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: - # pylint: disable=import-error,import-outside-toplevel + # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions return { diff --git a/superset/models/core.py 
b/superset/models/core.py index 58e45bc1ef2d..1eb97b262a63 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -563,7 +563,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument database=self, inspector=self.inspector, schema=schema ) return [(table, schema) for table in tables] - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @cache_util.memoized_func( @@ -593,7 +593,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument database=self, inspector=self.inspector, schema=schema ) return [(view, schema) for view in views] - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @cache_util.memoized_func( diff --git a/superset/views/core.py b/superset/views/core.py index b4d7c07be5ea..3252bb92371a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -327,7 +327,7 @@ def request_access(self) -> FlaskResponse: @has_access @event_logger.log_this @expose("/approve", methods=["POST"]) - def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use + def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): datasource = DatasourceDAO.get_datasource( From e089a8d2e71907dcadf8eb6b9d295022ba9d5e8c Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 8 Nov 2022 12:51:34 -0500 Subject: [PATCH 30/75] last one --- superset/db_engine_specs/base.py | 2 +- superset/db_engine_specs/trino.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index e47140f38485..d0ab4eb8bb87 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -18,7 +18,7 @@ import json import logging import re -from contextlib import closing, 
contextmanager +from contextlib import closing from datetime import datetime from typing import ( Any, diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 490a36ab2bbc..c62a9d58b434 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -224,7 +224,7 @@ def update_params_from_encrypted_extra( @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: - # pylint: disable=import-error,import-outside-toplevel + # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions return { From 45686b78a689bd9e7697b18012f9e4f97536eb45 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 9 Nov 2022 18:11:27 -0500 Subject: [PATCH 31/75] update naming on ssh tunnel --- .../databases/commands/test_connection.py | 4 ++-- superset/databases/models.py | 10 ++++------ superset/databases/schemas.py | 6 ++---- ...c8595_create_ssh_tunnel_credentials_tbl.py | 2 +- superset/models/core.py | 20 +++++++++++-------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 015a510127ca..fca146a9fa1b 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -90,10 +90,10 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - from superset.databases.models import SSHTunnelCredentials + from superset.databases.models import SSHTunnel if self._properties.get("ssh_tunnel_credentials"): - ssh_tunnel_credentials = SSHTunnelCredentials( + ssh_tunnel_credentials = SSHTunnel( **self._properties["ssh_tunnel_credentials"] ) diff --git a/superset/databases/models.py b/superset/databases/models.py index 57e47aa6f3b0..1fe255085724 100644 --- a/superset/databases/models.py +++ 
b/superset/databases/models.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict +from typing import Any, Dict import sqlalchemy as sa from flask_appbuilder import Model @@ -33,14 +33,12 @@ app_config = app.config -class SSHTunnelCredentials( - Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin -): +class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): """ A ssh tunnel configuration in a database. """ - __tablename__ = "ssh_tunnel_credentials" + __tablename__ = "ssh_tunnel" id = sa.Column(sa.Integer, primary_key=True) database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) @@ -70,7 +68,7 @@ class SSHTunnelCredentials( bind_host = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) bind_port = sa.Column(EncryptedType(sa.Integer, app_config["SECRET_KEY"])) - def parameters(self) -> Dict: + def parameters(self) -> Dict[str, Any]: params = { "ssh_address_or_host": self.server_address, "ssh_port": self.server_port, diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 250321e12b22..fad2da0feee4 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -456,7 +456,7 @@ class Meta: # pylint: disable=too-few-public-methods external_url = fields.String(allow_none=True) -class DatabaseSSHTunnelCredentials(Schema): +class DatabaseSSHTunnel(Schema): id = fields.Integer() database_id = fields.Integer() @@ -502,9 +502,7 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): validate=[Length(1, 1024), sqlalchemy_uri_validator], ) - ssh_tunnel_credentials = fields.Nested( - DatabaseSSHTunnelCredentials, allow_none=True - ) + ssh_tunnel_credentials = fields.Nested(DatabaseSSHTunnel, allow_none=True) class TableMetadataOptionsResponseSchema(Schema): diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py 
b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index b4c57470f71c..9023caa01791 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -72,4 +72,4 @@ def upgrade(): def downgrade(): - op.drop_table("ssh_tunnel_credentials") + op.drop_table("ssh_tunnel") diff --git a/superset/models/core.py b/superset/models/core.py index 3bc109dbe0be..6adb1189db1d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -24,7 +24,7 @@ from contextlib import closing, contextmanager from copy import deepcopy from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import numpy import pandas as pd @@ -72,6 +72,9 @@ metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from superset.databases.models import SSHTunnel + DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] @@ -369,23 +372,24 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel_credentials: Optional["SSHTunnelCredentials"] = None + ssh_tunnel_credentials: Optional["SSHTunnel"] = None ) -> Engine: + ssh_params = None if ssh_tunnel_credentials: # build with override print('building with params') + url = make_url_safe(self.sqlalchemy_uri_decrypted) + ssh_tunnel_credentials.bind_host = url.host + ssh_tunnel_credentials.bind_port = url.port + ssh_params = ssh_tunnel_credentials.parameters() else: # do look up in table for using database_id print('doing look up on table') try: - from sqlalchemy.engine.url import make_url - url = make_url(self.sqlalchemy_uri_decrypted) - ssh_tunnel_credentials.bind_host = url.host - 
ssh_tunnel_credentials.bind_port = url.port with sshtunnel.open_tunnel( - **ssh_tunnel_credentials.parameters() + **ssh_params ) as server: yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source, ssh_tunnel_server=server) except Exception as ex: @@ -444,7 +448,7 @@ def _get_sqla_engine( if ssh_tunnel_server: # update sqlalchemy_url from sqlalchemy.engine.url import make_url - url = make_url(sqlalchemy_url) + url = make_url_safe(sqlalchemy_url) sqlalchemy_url = url.set(host="127.0.0.1", port=ssh_tunnel_server.local_bind_port) try: From ec27b80bacf7b815f7c9390ad0848cf3d238fc53 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 10 Nov 2022 11:30:29 -0500 Subject: [PATCH 32/75] fix renaming --- superset/databases/commands/test_connection.py | 8 ++++---- superset/databases/models.py | 2 +- superset/models/core.py | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index fca146a9fa1b..f98c7657eea2 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -92,9 +92,9 @@ def run(self) -> None: # pylint: disable=too-many-statements from superset.databases.models import SSHTunnel - if self._properties.get("ssh_tunnel_credentials"): - ssh_tunnel_credentials = SSHTunnel( - **self._properties["ssh_tunnel_credentials"] + if self._properties.get("ssh_tunnel"): + ssh_tunnel = SSHTunnel( + **self._properties["ssh_tunnel"] ) event_logger.log_with_context( @@ -107,7 +107,7 @@ def ping(engine: Engine) -> bool: return engine.dialect.do_ping(conn) with database.get_sqla_engine_with_context( - ssh_tunnel_credentials=ssh_tunnel_credentials + ssh_tunnel=ssh_tunnel ) as engine: try: alive = func_timeout( diff --git a/superset/databases/models.py b/superset/databases/models.py index 1fe255085724..02ece395ea26 100644 --- a/superset/databases/models.py +++ b/superset/databases/models.py @@ -44,7 
+44,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) database: Database = relationship( "Database", - backref=backref("ssh_tunnel_credentials", cascade="all, delete-orphan"), + backref=backref("ssh_tunnel", cascade="all, delete-orphan"), foreign_keys=[database_id], ) diff --git a/superset/models/core.py b/superset/models/core.py index 6adb1189db1d..2e9dcb8e68ba 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,16 +372,16 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel_credentials: Optional["SSHTunnel"] = None + ssh_tunnel: Optional["SSHTunnel"] = None ) -> Engine: ssh_params = None - if ssh_tunnel_credentials: + if ssh_tunnel: # build with override print('building with params') url = make_url_safe(self.sqlalchemy_uri_decrypted) - ssh_tunnel_credentials.bind_host = url.host - ssh_tunnel_credentials.bind_port = url.port - ssh_params = ssh_tunnel_credentials.parameters() + ssh_tunnel.bind_host = url.host + ssh_tunnel.bind_port = url.port + ssh_params = ssh_tunnel.parameters() else: # do look up in table for using database_id From 65e3e29386c32f2346f24440277da32e459b4d16 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 10 Nov 2022 11:50:57 -0500 Subject: [PATCH 33/75] fix renaming 2 --- superset/databases/commands/test_connection.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index f98c7657eea2..55647826fdd6 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -93,9 +93,7 @@ def run(self) -> None: # pylint: disable=too-many-statements from superset.databases.models import SSHTunnel if self._properties.get("ssh_tunnel"): - ssh_tunnel = SSHTunnel( - 
**self._properties["ssh_tunnel"] - ) + ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) event_logger.log_with_context( action="test_connection_attempt", @@ -106,9 +104,7 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context( - ssh_tunnel=ssh_tunnel - ) as engine: + with database.get_sqla_engine_with_context(ssh_tunnel=ssh_tunnel) as engine: try: alive = func_timeout( int( From 1a11ff4dc31193c0659f2c579950b41463788a87 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 10 Nov 2022 12:16:46 -0500 Subject: [PATCH 34/75] oops --- requirements/development.txt | 2 +- requirements/docker.txt | 2 +- requirements/local.txt | 2 +- superset/models/core.py | 21 ++++++++++++--------- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/requirements/development.txt b/requirements/development.txt index adb5ed36cd5a..1bce530eab73 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt --e file:///Users/hugh/src/superset +-e file:. # via # -r requirements/base.in # -r requirements/development.in diff --git a/requirements/docker.txt b/requirements/docker.txt index 9d88a0fecb53..307064dbdedd 100644 --- a/requirements/docker.txt +++ b/requirements/docker.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt --e file:///Users/hugh/src/superset +-e file:. # via # -r requirements/base.in # -r requirements/docker.in diff --git a/requirements/local.txt b/requirements/local.txt index 10280f670daf..c4bd3cd599b3 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r development.txt --e file:///Users/hugh/src/superset +-e file:. 
# via # -r requirements/base.in # -r requirements/development.in diff --git a/superset/models/core.py b/superset/models/core.py index 2e9dcb8e68ba..a7c278573ce7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -374,7 +374,7 @@ def get_sqla_engine_with_context( source: Optional[utils.QuerySource] = None, ssh_tunnel: Optional["SSHTunnel"] = None ) -> Engine: - ssh_params = None + ssh_params = {} if ssh_tunnel: # build with override print('building with params') @@ -382,18 +382,21 @@ def get_sqla_engine_with_context( ssh_tunnel.bind_host = url.host ssh_tunnel.bind_port = url.port ssh_params = ssh_tunnel.parameters() + try: + with sshtunnel.open_tunnel( + **ssh_params + ) as server: + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source, ssh_tunnel_server=server) + except Exception as ex: + raise ex else: # do look up in table for using database_id print('doing look up on table') - - try: - with sshtunnel.open_tunnel( - **ssh_params - ) as server: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source, ssh_tunnel_server=server) - except Exception as ex: - raise ex + try: + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + except Exception as ex: + raise ex import sshtunnel def _get_sqla_engine( From 3f0dae1e60dc6a978c0ec127289235ff3c36fa48 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 10 Nov 2022 16:49:17 -0500 Subject: [PATCH 35/75] fix linting errors --- .../databases/commands/test_connection.py | 3 +-- superset/models/core.py | 27 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 55647826fdd6..bebd9fc57465 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,6 +32,7 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from 
superset.databases.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -90,8 +91,6 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - from superset.databases.models import SSHTunnel - if self._properties.get("ssh_tunnel"): ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) diff --git a/superset/models/core.py b/superset/models/core.py index a7c278573ce7..d7a8419a6bb6 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,33 +372,37 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel: Optional["SSHTunnel"] = None + ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: ssh_params = {} if ssh_tunnel: # build with override - print('building with params') + print("building with params") url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_tunnel.bind_host = url.host ssh_tunnel.bind_port = url.port ssh_params = ssh_tunnel.parameters() try: - with sshtunnel.open_tunnel( - **ssh_params - ) as server: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source, ssh_tunnel_server=server) + with sshtunnel.open_tunnel(**ssh_params) as server: + yield self._get_sqla_engine( + schema=schema, + nullpool=nullpool, + source=source, + ssh_tunnel_server=server, + ) except Exception as ex: raise ex else: # do look up in table for using database_id - print('doing look up on table') + print("doing look up on table") try: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + yield self._get_sqla_engine( + schema=schema, nullpool=nullpool, source=source + ) except Exception as ex: raise ex - import sshtunnel def _get_sqla_engine( self, schema: Optional[str] = None, @@ -450,9 +454,10 @@ def 
_get_sqla_engine( if ssh_tunnel_server: # update sqlalchemy_url - from sqlalchemy.engine.url import make_url url = make_url_safe(sqlalchemy_url) - sqlalchemy_url = url.set(host="127.0.0.1", port=ssh_tunnel_server.local_bind_port) + sqlalchemy_url = url.set( + host="127.0.0.1", port=ssh_tunnel_server.local_bind_port + ) try: return create_engine(sqlalchemy_url, **params) From 27778070cddbaf97ad80890bcab4858d05d52098 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Tue, 15 Nov 2022 18:32:09 -0300 Subject: [PATCH 36/75] feat(ssh_tunnel): DAO Changes for SSH Tunnel (#22120) --- superset/databases/dao.py | 11 ++++ superset/databases/ssh_tunnel_dao.py | 26 ++++++++ tests/unit_tests/databases/dao/__init__.py | 16 +++++ tests/unit_tests/databases/dao/dao_tests.py | 70 +++++++++++++++++++++ 4 files changed, 123 insertions(+) create mode 100644 superset/databases/ssh_tunnel_dao.py create mode 100644 tests/unit_tests/databases/dao/__init__.py create mode 100644 tests/unit_tests/databases/dao/dao_tests.py diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 568755dd3272..c8565e025a85 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -19,6 +19,7 @@ from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter +from superset.databases.models import SSHTunnel from superset.extensions import db from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -124,3 +125,13 @@ def get_related_objects(cls, database_id: int) -> Dict[str, Any]: return dict( charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states ) + + @classmethod + def get_ssh_tunnel(cls, database_id: int) -> Dict[str, Any]: + ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == database_id) + .one_or_none() + ) + + return dict(ssh_tunnel=ssh_tunnel) diff --git a/superset/databases/ssh_tunnel_dao.py 
b/superset/databases/ssh_tunnel_dao.py new file mode 100644 index 000000000000..b7afaa388200 --- /dev/null +++ b/superset/databases/ssh_tunnel_dao.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from superset.dao.base import BaseDAO +from superset.databases.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class SSHTunnelDAO(BaseDAO): + model_cls = SSHTunnel diff --git a/tests/unit_tests/databases/dao/__init__.py b/tests/unit_tests/databases/dao/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/databases/dao/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py new file mode 100644 index 000000000000..5a6c2114b9b2 --- /dev/null +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel( + database_id=db.id, + database=db, + ) + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_database_get_shh_tunnel(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result["ssh_tunnel"], SSHTunnel) + assert 1 == result["ssh_tunnel"].database_id + + +def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + + result = DatabaseDAO.get_ssh_tunnel(2) + + assert result + assert result["ssh_tunnel"] is None From 6bd32e8abb84f7bf0d2dca5d7f3f436579512ec3 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:44:32 -0300 Subject: [PATCH 37/75] feat(ssh_tunnel): Delete command & exceptions (#22131) --- .../databases/commands/test_connection.py | 2 +- superset/databases/dao.py | 2 +- superset/databases/ssh_tunnel/__init__.py | 16 +++++ .../databases/ssh_tunnel/commands/delete.py | 52 ++++++++++++++ .../ssh_tunnel/commands/exceptions.py | 28 ++++++++ .../{ssh_tunnel_dao.py => ssh_tunnel/dao.py} | 2 +- superset/databases/{ => ssh_tunnel}/models.py | 0 
superset/models/core.py | 2 +- tests/unit_tests/databases/dao/dao_tests.py | 4 +- .../databases/ssh_tunnel/__init__.py | 16 +++++ .../databases/ssh_tunnel/commands/__init__.py | 16 +++++ .../ssh_tunnel/commands/delete_test.py | 69 +++++++++++++++++++ 12 files changed, 203 insertions(+), 6 deletions(-) create mode 100644 superset/databases/ssh_tunnel/__init__.py create mode 100644 superset/databases/ssh_tunnel/commands/delete.py create mode 100644 superset/databases/ssh_tunnel/commands/exceptions.py rename superset/databases/{ssh_tunnel_dao.py => ssh_tunnel/dao.py} (94%) rename superset/databases/{ => ssh_tunnel}/models.py (100%) create mode 100644 tests/unit_tests/databases/ssh_tunnel/__init__.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/__init__.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index bc241b2bcd8f..106452589ba2 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,7 +32,7 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO -from superset.databases.models import SSHTunnel +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( diff --git a/superset/databases/dao.py b/superset/databases/dao.py index c8565e025a85..e947db89d9c6 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -19,7 +19,7 @@ from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter -from superset.databases.models import SSHTunnel +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.extensions import db from superset.models.core import Database from superset.models.dashboard import Dashboard diff 
--git a/superset/databases/ssh_tunnel/__init__.py b/superset/databases/ssh_tunnel/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py new file mode 100644 index 000000000000..d39e395a34cd --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask_appbuilder.models.sqla import Model + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAODeleteFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + SSHTunnelNotFoundError, +) +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class DeleteSSHTunnelCommand(BaseCommand): + def __init__(self, model_id: int): + self._model_id = model_id + self._model: Optional[SSHTunnel] = None + + def run(self) -> Model: + self.validate() + try: + ssh_tunnel = SSHTunnelDAO.delete(self._model) + except DAODeleteFailedError as ex: + logger.exception(ex.exception) + raise SSHTunnelDeleteFailedError() from ex + return ssh_tunnel + + def validate(self) -> None: + # Validate/populate model exists + self._model = SSHTunnelDAO.find_by_id(self._model_id) + if not self._model: + raise SSHTunnelNotFoundError() diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py new file mode 100644 index 000000000000..f535e6ced86f --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ + +from superset.commands.exceptions import CommandException, DeleteFailedError + + +class SSHTunnelDeleteFailedError(DeleteFailedError): + message = _("SSH Tunnel could not be deleted.") + + +class SSHTunnelNotFoundError(CommandException): + status = 404 + message = _("SSH Tunnel not found.") diff --git a/superset/databases/ssh_tunnel_dao.py b/superset/databases/ssh_tunnel/dao.py similarity index 94% rename from superset/databases/ssh_tunnel_dao.py rename to superset/databases/ssh_tunnel/dao.py index b7afaa388200..924148164482 100644 --- a/superset/databases/ssh_tunnel_dao.py +++ b/superset/databases/ssh_tunnel/dao.py @@ -17,7 +17,7 @@ import logging from superset.dao.base import BaseDAO -from superset.databases.models import SSHTunnel +from superset.databases.ssh_tunnel.models import SSHTunnel logger = logging.getLogger(__name__) diff --git a/superset/databases/models.py b/superset/databases/ssh_tunnel/models.py similarity index 100% rename from superset/databases/models.py rename to superset/databases/ssh_tunnel/models.py diff --git a/superset/models/core.py b/superset/models/core.py index 40679bb79eb1..d71b03137d96 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -73,7 +73,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from superset.databases.models import SSHTunnel + from superset.databases.ssh_tunnel.models import SSHTunnel DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] diff --git a/tests/unit_tests/databases/dao/dao_tests.py 
b/tests/unit_tests/databases/dao/dao_tests.py index 5a6c2114b9b2..a5a828d79da7 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -24,7 +24,7 @@ @pytest.fixture def session_with_data(session: Session) -> Iterator[Session]: from superset.connectors.sqla.models import SqlaTable - from superset.databases.models import SSHTunnel + from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database engine = session.get_bind() @@ -52,7 +52,7 @@ def session_with_data(session: Session) -> Iterator[Session]: def test_database_get_shh_tunnel(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO - from superset.databases.models import SSHTunnel + from superset.databases.ssh_tunnel.models import SSHTunnel result = DatabaseDAO.get_ssh_tunnel(1) diff --git a/tests/unit_tests/databases/ssh_tunnel/__init__.py b/tests/unit_tests/databases/ssh_tunnel/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py b/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py new file mode 100644 index 000000000000..b0967228e291 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel( + database_id=db.id, + database=db, + ) + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_delete_shh_tunnel_command(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result["ssh_tunnel"], SSHTunnel) + assert 1 == result["ssh_tunnel"].database_id + + DeleteSSHTunnelCommand(1).run() + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert result["ssh_tunnel"] is None From adb94517ae52a45641d7aba4486826d7ff6dee6e Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 16 Nov 2022 23:33:36 -0500 Subject: [PATCH 38/75] fix indenting for superset/databases/commands/validate.py --- 
superset/databases/commands/validate.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 05e139bf817d..8c58ef5de0bf 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -118,14 +118,14 @@ def run(self) -> None: errors = database.db_engine_spec.extract_errors(ex, context) raise DatabaseTestConnectionFailedError(errors) from ex - if not alive: - raise DatabaseOfflineError( - SupersetError( - message=__("Database is offline."), - error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - level=ErrorLevel.ERROR, - ), - ) + if not alive: + raise DatabaseOfflineError( + SupersetError( + message=__("Database is offline."), + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ), + ) def validate(self) -> None: database_id = self._properties.get("id") From 16d960b382fe203a0c855ba425144245e38b97fb Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 16 Nov 2022 23:38:27 -0500 Subject: [PATCH 39/75] change tablename --- superset/databases/ssh_tunnel/models.py | 4 ++-- ...20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 02ece395ea26..ffbab0186ebd 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -38,13 +38,13 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): A ssh tunnel configuration in a database. 
""" - __tablename__ = "ssh_tunnel" + __tablename__ = "ssh_tunnels" id = sa.Column(sa.Integer, primary_key=True) database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) database: Database = relationship( "Database", - backref=backref("ssh_tunnel", cascade="all, delete-orphan"), + backref=backref("ssh_tunnels", cascade="all, delete-orphan"), foreign_keys=[database_id], ) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 9023caa01791..e0b874fccdcf 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -40,7 +40,7 @@ def upgrade(): op.create_table( - "ssh_tunnel_credentials", + "ssh_tunnels", # AuditMixinNullable sa.Column("created_on", sa.DateTime(), nullable=True), sa.Column("changed_on", sa.DateTime(), nullable=True), @@ -72,4 +72,4 @@ def upgrade(): def downgrade(): - op.drop_table("ssh_tunnel") + op.drop_table("ssh_tunnels") From d2ab4a68789a44729c9c7c331f17c536816e55bc Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Thu, 17 Nov 2022 14:55:16 -0300 Subject: [PATCH 40/75] feat(ssh_tunnel): DELETE SSH Tunnels API (#22153) --- superset/constants.py | 1 + superset/databases/api.py | 61 +++++++++++++++++++++ tests/unit_tests/databases/api_test.py | 73 ++++++++++++++++++++++++++ 3 files changed, 135 insertions(+) diff --git a/superset/constants.py b/superset/constants.py index 7d759acf6741..b775926b2158 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -136,6 +136,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "validate_sql": "read", "get_data": "read", "samples": "read", + "delete_ssh_tunnel": "write", } 
EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index aced8e7c6faa..92b9635fd8e4 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,6 +72,11 @@ ValidateSQLRequest, ValidateSQLResponse, ) +from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + SSHTunnelNotFoundError, +) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -107,6 +112,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "available", "validate_parameters", "validate_sql", + "delete_ssh_tunnel", } resource_name = "database" class_permission_name = "Database" @@ -1204,3 +1210,58 @@ def validate_parameters(self) -> FlaskResponse: command = ValidateDatabaseParametersCommand(payload) command.run() return self.response(200, message="OK") + + @expose("//ssh_tunnel/", methods=["DELETE"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".delete_ssh_tunnel", + log_to_statsd=False, + ) + def delete_ssh_tunnel(self, pk: int) -> Response: + """Deletes a SSH Tunnel + --- + delete: + description: >- + Deletes a SSH Tunnel. 
+ parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: SSH Tunnel deleted + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteSSHTunnelCommand(pk).run() + return self.response(200, message="OK") + except SSHTunnelNotFoundError: + return self.response_404() + except SSHTunnelDeleteFailedError as ex: + logger.error( + "Error deleting SSH Tunnel %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=str(ex)) diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index d6f8897c4a09..9d95653a22f4 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -191,3 +191,76 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None: } ] } + + +def test_delete_ssh_tunnel( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we can delete SSH Tunnel + """ + with app.app_context(): + from superset.databases.api import DatabaseRestApi + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + 
"private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) + + session.add(tunnel) + session.commit() + + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel["ssh_tunnel"], SSHTunnel) + assert 1 == response_tunnel["ssh_tunnel"].database_id + + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "OK" + + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert response_tunnel["ssh_tunnel"] is None From fb2acd09ffb8cc86121637304c28d1a59dbc1bd7 Mon Sep 17 00:00:00 2001 From: "Hugh A. 
Miles II" Date: Thu, 17 Nov 2022 12:55:35 -0500 Subject: [PATCH 41/75] Revert "feat(ssh_tunnel): DELETE SSH Tunnels API" (#22156) --- superset/constants.py | 1 - superset/databases/api.py | 61 --------------------- tests/unit_tests/databases/api_test.py | 73 -------------------------- 3 files changed, 135 deletions(-) diff --git a/superset/constants.py b/superset/constants.py index b775926b2158..7d759acf6741 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -136,7 +136,6 @@ class RouteMethod: # pylint: disable=too-few-public-methods "validate_sql": "read", "get_data": "read", "samples": "read", - "delete_ssh_tunnel": "write", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index 92b9635fd8e4..aced8e7c6faa 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,11 +72,6 @@ ValidateSQLRequest, ValidateSQLResponse, ) -from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand -from superset.databases.ssh_tunnel.commands.exceptions import ( - SSHTunnelDeleteFailedError, - SSHTunnelNotFoundError, -) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -112,7 +107,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "available", "validate_parameters", "validate_sql", - "delete_ssh_tunnel", } resource_name = "database" class_permission_name = "Database" @@ -1210,58 +1204,3 @@ def validate_parameters(self) -> FlaskResponse: command = ValidateDatabaseParametersCommand(payload) command.run() return self.response(200, message="OK") - - @expose("//ssh_tunnel/", methods=["DELETE"]) - @protect() - @safe - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" - f".delete_ssh_tunnel", - log_to_statsd=False, - ) - def delete_ssh_tunnel(self, pk: int) -> 
Response: - """Deletes a SSH Tunnel - --- - delete: - description: >- - Deletes a SSH Tunnel. - parameters: - - in: path - schema: - type: integer - name: pk - responses: - 200: - description: SSH Tunnel deleted - content: - application/json: - schema: - type: object - properties: - message: - type: string - 401: - $ref: '#/components/responses/401' - 403: - $ref: '#/components/responses/403' - 404: - $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' - 500: - $ref: '#/components/responses/500' - """ - try: - DeleteSSHTunnelCommand(pk).run() - return self.response(200, message="OK") - except SSHTunnelNotFoundError: - return self.response_404() - except SSHTunnelDeleteFailedError as ex: - logger.error( - "Error deleting SSH Tunnel %s: %s", - self.__class__.__name__, - str(ex), - exc_info=True, - ) - return self.response_422(message=str(ex)) diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 9d95653a22f4..d6f8897c4a09 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -191,76 +191,3 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None: } ] } - - -def test_delete_ssh_tunnel( - mocker: MockFixture, - app: Any, - session: Session, - client: Any, - full_api_access: None, -) -> None: - """ - Test that we can delete SSH Tunnel - """ - with app.app_context(): - from superset.databases.api import DatabaseRestApi - from superset.databases.dao import DatabaseDAO - from superset.databases.ssh_tunnel.models import SSHTunnel - from superset.models.core import Database - - DatabaseRestApi.datamodel.session = session - - # create table for databases - Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member - - # Create our Database - database = Database( - database_name="my_database", - sqlalchemy_uri="gsheets://", - encrypted_extra=json.dumps( - { - "service_account_info": { - "type": "service_account", - "project_id": 
"black-sanctum-314419", - "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", - "private_key": "SECRET", - "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", - "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", - }, - } - ), - ) - session.add(database) - session.commit() - - # mock the lookup so that we don't need to include the driver - mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") - mocker.patch("superset.utils.log.DBEventLogger.log") - - # Create our SSHTunnel - tunnel = SSHTunnel( - database_id=1, - database=database, - ) - - session.add(tunnel) - session.commit() - - # Get our recently created SSHTunnel - response_tunnel = DatabaseDAO.get_ssh_tunnel(1) - assert response_tunnel - assert isinstance(response_tunnel["ssh_tunnel"], SSHTunnel) - assert 1 == response_tunnel["ssh_tunnel"].database_id - - # Delete the recently created SSHTunnel - response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/") - assert response_delete_tunnel.json["message"] == "OK" - - response_tunnel = DatabaseDAO.get_ssh_tunnel(1) - assert response_tunnel - assert response_tunnel["ssh_tunnel"] is None From 4d807c9e14e8219e80d10353516722be2e3c13e6 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Thu, 17 Nov 2022 14:56:58 -0300 Subject: [PATCH 42/75] feat(ssh_tunnel): Update command & exceptions (#22132) --- .../databases/ssh_tunnel/commands/delete.py | 1 - .../ssh_tunnel/commands/exceptions.py | 15 +++- .../databases/ssh_tunnel/commands/update.py | 63 +++++++++++++++++ 
.../ssh_tunnel/commands/update_test.py | 69 +++++++++++++++++++ 4 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 superset/databases/ssh_tunnel/commands/update.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/update_test.py diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py index d39e395a34cd..3ad2fc2a1506 100644 --- a/superset/databases/ssh_tunnel/commands/delete.py +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -41,7 +41,6 @@ def run(self) -> Model: try: ssh_tunnel = SSHTunnelDAO.delete(self._model) except DAODeleteFailedError as ex: - logger.exception(ex.exception) raise SSHTunnelDeleteFailedError() from ex return ssh_tunnel diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py index f535e6ced86f..ccceb440891c 100644 --- a/superset/databases/ssh_tunnel/commands/exceptions.py +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -16,7 +16,12 @@ # under the License. 
from flask_babel import lazy_gettext as _ -from superset.commands.exceptions import CommandException, DeleteFailedError +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + DeleteFailedError, + UpdateFailedError, +) class SSHTunnelDeleteFailedError(DeleteFailedError): @@ -26,3 +31,11 @@ class SSHTunnelDeleteFailedError(DeleteFailedError): class SSHTunnelNotFoundError(CommandException): status = 404 message = _("SSH Tunnel not found.") + + +class SSHTunnelInvalidError(CommandInvalidError): + message = _("SSH Tunnel parameters are invalid.") + + +class SSHTunnelUpdateFailedError(UpdateFailedError): + message = _("SSH Tunnel could not be updated.") diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py new file mode 100644 index 000000000000..fa076d90b2e9 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOUpdateFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + SSHTunnelInvalidError, + SSHTunnelNotFoundError, + SSHTunnelUpdateFailedError, +) +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class UpdateSSHTunnelCommand(BaseCommand): + def __init__(self, model_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._model_id = model_id + self._model: Optional[SSHTunnel] = None + + def run(self) -> Model: + self.validate() + if not self._model: + raise SSHTunnelNotFoundError() + + try: + tunnel = SSHTunnelDAO.update(self._model, self._properties, commit=True) + except DAOUpdateFailedError as ex: + raise SSHTunnelUpdateFailedError() from ex + return tunnel + + def validate(self) -> None: + exceptions: List[ValidationError] = [] + # Validate/populate model exists + self._model = SSHTunnelDAO.find_by_id(self._model_id) + if not self._model: + raise SSHTunnelNotFoundError() + if exceptions: + exception = SSHTunnelInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py new file mode 100644 index 000000000000..f85b984aa8ed --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test") + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_update_shh_tunnel_command(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result["ssh_tunnel"], SSHTunnel) + assert 1 == result["ssh_tunnel"].database_id + assert "Test" == result["ssh_tunnel"].server_address + + update_payload = {"server_address": "Test2"} + UpdateSSHTunnelCommand(1, update_payload).run() + + result = 
DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result["ssh_tunnel"], SSHTunnel) + assert "Test2" == result["ssh_tunnel"].server_address From dc0c84851637582760bbce2663a8b56af6868d28 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 17 Nov 2022 14:47:11 -0500 Subject: [PATCH 43/75] forgot server_port --- superset/databases/ssh_tunnel/models.py | 2 +- ...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index ffbab0186ebd..f7c8c9618106 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -49,7 +49,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): ) server_address = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) - server_port = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) + server_port = sa.Column(EncryptedType(sa.Integer, app_config["SECRET_KEY"])) username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) # basic authentication diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index e0b874fccdcf..8c1d033450a2 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -54,6 +54,7 @@ def upgrade(): sa.Column("id", sa.Integer(), primary_key=True), sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id")), sa.Column("server_address", encrypted_field_factory.create(sa.String(1024))), + sa.Column("server_port", encrypted_field_factory.create(sa.INTEGER())), sa.Column("username", encrypted_field_factory.create(sa.String(1024))), sa.Column( "password", 
encrypted_field_factory.create(sa.String(1024)), nullable=True From 21fcdf0ecbf699e9d4c12893f28256058e3de979 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 17 Nov 2022 15:02:36 -0500 Subject: [PATCH 44/75] bind_port + bind_host :) --- ...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 8c1d033450a2..09a4af0ed269 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -69,6 +69,8 @@ def upgrade(): encrypted_field_factory.create(sa.String(1024)), nullable=True, ), + sa.Column("bind_host", encrypted_field_factory.create(sa.String(1024))), + sa.Column("port_port", encrypted_field_factory.create(sa.INTEGER())), ) From 68cb75fe1745600038b6398ecfadc96d009379fd Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 17 Nov 2022 15:15:03 -0500 Subject: [PATCH 45/75] oops --- superset/databases/commands/test_connection.py | 6 ++++-- ..._10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 106452589ba2..618fc0bfe404 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -91,8 +91,10 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - if self._properties.get("ssh_tunnel"): - ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) + # TODO: (hughhh) uncomment in API enablement PR + ssh_tunnel = None + # if 
self._properties.get("ssh_tunnel"): + # ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) event_logger.log_with_context( action="test_connection_attempt", diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 09a4af0ed269..690d7cd5037a 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -70,7 +70,7 @@ def upgrade(): nullable=True, ), sa.Column("bind_host", encrypted_field_factory.create(sa.String(1024))), - sa.Column("port_port", encrypted_field_factory.create(sa.INTEGER())), + sa.Column("bind_port", encrypted_field_factory.create(sa.INTEGER())), ) From 44ca56b187e6d659f7f247c00ed70e5071a04c74 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 17 Nov 2022 16:03:39 -0500 Subject: [PATCH 46/75] fix linting --- superset/databases/commands/test_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 618fc0bfe404..7b913ed20298 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,7 +32,6 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO -from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( From 7e1461e7de83cb5bb330068db4b5da9e64f15150 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Mon, 21 Nov 2022 15:01:06 -0300 Subject: [PATCH 47/75] feat(ssh_tunnel): SSH Tunnel updates from Code Review (#22182) --- 
superset/databases/dao.py | 2 +- superset/databases/ssh_tunnel/commands/update.py | 10 +--------- tests/unit_tests/databases/dao/dao_tests.py | 7 +++---- .../databases/ssh_tunnel/commands/delete_test.py | 7 +++---- .../databases/ssh_tunnel/commands/update_test.py | 10 +++++----- 5 files changed, 13 insertions(+), 23 deletions(-) diff --git a/superset/databases/dao.py b/superset/databases/dao.py index e947db89d9c6..4e9ad1ee8c13 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -134,4 +134,4 @@ def get_ssh_tunnel(cls, database_id: int) -> Dict[str, Any]: .one_or_none() ) - return dict(ssh_tunnel=ssh_tunnel) + return ssh_tunnel diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index fa076d90b2e9..a9d9a6fe6d18 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -42,22 +42,14 @@ def __init__(self, model_id: int, data: Dict[str, Any]): def run(self) -> Model: self.validate() - if not self._model: - raise SSHTunnelNotFoundError() - try: - tunnel = SSHTunnelDAO.update(self._model, self._properties, commit=True) + tunnel = SSHTunnelDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: raise SSHTunnelUpdateFailedError() from ex return tunnel def validate(self) -> None: - exceptions: List[ValidationError] = [] # Validate/populate model exists self._model = SSHTunnelDAO.find_by_id(self._model_id) if not self._model: raise SSHTunnelNotFoundError() - if exceptions: - exception = SSHTunnelInvalidError() - exception.add_list(exceptions) - raise exception diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index a5a828d79da7..bb01c1d2f615 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -57,8 +57,8 @@ def test_database_get_shh_tunnel(session_with_data: Session) -> None: result = 
DatabaseDAO.get_ssh_tunnel(1) assert result - assert isinstance(result["ssh_tunnel"], SSHTunnel) - assert 1 == result["ssh_tunnel"].database_id + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None: @@ -66,5 +66,4 @@ def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None: result = DatabaseDAO.get_ssh_tunnel(2) - assert result - assert result["ssh_tunnel"] is None + assert result is None diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index b0967228e291..d70dcd43d903 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -58,12 +58,11 @@ def test_delete_shh_tunnel_command(session_with_data: Session) -> None: result = DatabaseDAO.get_ssh_tunnel(1) assert result - assert isinstance(result["ssh_tunnel"], SSHTunnel) - assert 1 == result["ssh_tunnel"].database_id + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id DeleteSSHTunnelCommand(1).run() result = DatabaseDAO.get_ssh_tunnel(1) - assert result - assert result["ssh_tunnel"] is None + assert result is None diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index f85b984aa8ed..4b8a1fd2a002 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -55,9 +55,9 @@ def test_update_shh_tunnel_command(session_with_data: Session) -> None: result = DatabaseDAO.get_ssh_tunnel(1) assert result - assert isinstance(result["ssh_tunnel"], SSHTunnel) - assert 1 == result["ssh_tunnel"].database_id - assert "Test" == result["ssh_tunnel"].server_address + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == 
result.server_address update_payload = {"server_address": "Test2"} UpdateSSHTunnelCommand(1, update_payload).run() @@ -65,5 +65,5 @@ def test_update_shh_tunnel_command(session_with_data: Session) -> None: result = DatabaseDAO.get_ssh_tunnel(1) assert result - assert isinstance(result["ssh_tunnel"], SSHTunnel) - assert "Test2" == result["ssh_tunnel"].server_address + assert isinstance(result, SSHTunnel) + assert "Test2" == result.server_address From 6c59663638e0293da4f243dcb8f13f5a071a79ef Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Tue, 22 Nov 2022 13:42:25 -0500 Subject: [PATCH 48/75] feat(ssh_tunnel): Create command & exceptions (#22148) --- .../databases/ssh_tunnel/commands/create.py | 50 +++++++++++++++++++ .../ssh_tunnel/commands/exceptions.py | 4 ++ tests/unit_tests/databases/dao/dao_tests.py | 4 +- .../ssh_tunnel/commands/create_test.py | 42 ++++++++++++++++ .../ssh_tunnel/commands/delete_test.py | 2 +- .../databases/ssh_tunnel/dao_tests.py | 43 ++++++++++++++++ 6 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 superset/databases/ssh_tunnel/commands/create.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/create_test.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/dao_tests.py diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py new file mode 100644 index 000000000000..1185d20bdf7c --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOCreateFailedError +from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelCreateFailedError +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO + +logger = logging.getLogger(__name__) + + +class CreateSSHTunnelCommand(BaseCommand): + def __init__(self, database_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._properties["database_id"] = database_id + + def run(self) -> Model: + self.validate() + + try: + tunnel = SSHTunnelDAO.create(self._properties, commit=False) + except DAOCreateFailedError as ex: + raise SSHTunnelCreateFailedError() from ex + + return tunnel + + def validate(self) -> None: + # TODO(hughhh): check to make sure the server port is not localhost + # using the config.SSH_TUNNEL_MANAGER + return diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py index ccceb440891c..9e3bce81a64f 100644 --- a/superset/databases/ssh_tunnel/commands/exceptions.py +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -39,3 +39,7 @@ class SSHTunnelInvalidError(CommandInvalidError): class SSHTunnelUpdateFailedError(UpdateFailedError): message = _("SSH Tunnel could not be updated.") + + +class SSHTunnelCreateFailedError(CommandException): + message = 
_("Creating SSH Tunnel failed for an unknown reason") diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index bb01c1d2f615..47db402670de 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -50,7 +50,7 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_database_get_shh_tunnel(session_with_data: Session) -> None: +def test_database_get_ssh_tunnel(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel @@ -61,7 +61,7 @@ def test_database_get_shh_tunnel(session_with_data: Session) -> None: assert 1 == result.database_id -def test_database_get_shh_tunnel_not_found(session_with_data: Session) -> None: +def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO result = DatabaseDAO.get_ssh_tunnel(2) diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py new file mode 100644 index 000000000000..2b6b8d9aaebb --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +def test_create_ssh_tunnel_command() -> None: + from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = CreateSSHTunnelCommand(db.id, properties).run() + + assert result is not None + assert isinstance(result, SSHTunnel) diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index d70dcd43d903..17afebfa0fec 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -50,7 +50,7 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_delete_shh_tunnel_command(session_with_data: Session) -> None: +def test_delete_ssh_tunnel_command(session_with_data: Session) -> None: from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand from superset.databases.ssh_tunnel.models import SSHTunnel diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py new file mode 100644 index 000000000000..ae5b6e9bd3c3 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +def test_create_ssh_tunnel(): + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.dao import SSHTunnelDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = SSHTunnelDAO.create(properties) + + assert result is not None + assert isinstance(result, SSHTunnel) From 466703a9fed25b48d97b9663b62fc437fa69ea03 Mon Sep 17 00:00:00 2001 From: "Hugh A. 
Miles II" Date: Mon, 28 Nov 2022 14:38:51 -0500 Subject: [PATCH 49/75] Update schemas.py --- superset/databases/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index fad2da0feee4..b874d147a22a 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -502,7 +502,7 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): validate=[Length(1, 1024), sqlalchemy_uri_validator], ) - ssh_tunnel_credentials = fields.Nested(DatabaseSSHTunnel, allow_none=True) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class TableMetadataOptionsResponseSchema(Schema): From 44487396cef334d57bab8fd97100d943a90e0934 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Tue, 29 Nov 2022 19:16:36 -0500 Subject: [PATCH 50/75] chore(ssh-tunnel): create `contextmanager` for sql.inspect (#22251) --- .../databases/commands/test_connection.py | 14 ++- superset/databases/schemas.py | 4 - superset/models/core.py | 106 ++++++++++-------- tests/conftest.py | 5 +- 4 files changed, 71 insertions(+), 58 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 7b913ed20298..1fbf16efa70a 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,6 +32,7 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -90,10 +91,13 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - # TODO: (hughhh) uncomment in API enablement PR + # Generate tunnel if present in the properties ssh_tunnel 
= None - # if self._properties.get("ssh_tunnel"): - # ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) + if ssh_tunnel := self._properties.get("ssh_tunnel"): + url = make_url_safe(database.sqlalchemy_uri_decrypted) + ssh_tunnel["bind_host"] = url.host + ssh_tunnel["bind_port"] = url.port + ssh_tunnel = SSHTunnel(**ssh_tunnel) event_logger.log_with_context( action="test_connection_attempt", @@ -104,7 +108,9 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context(ssh_tunnel=ssh_tunnel) as engine: + with database.get_sqla_engine_with_context( + override_ssh_tunnel=ssh_tunnel + ) as engine: try: alive = func_timeout( app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(), diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index b874d147a22a..2dcf5ac033e2 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -471,10 +471,6 @@ class DatabaseSSHTunnel(Schema): private_key = fields.String(required=False) private_key_password = fields.String(required=False) - # remote binding port - bind_host = fields.String() - bind_port = fields.Integer() - class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): diff --git a/superset/models/core.py b/superset/models/core.py index fd6ad8be9347..095da48148ae 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,34 +372,31 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel: Optional["SSHTunnel"] = None, + override_ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: ssh_params = {} - if ssh_tunnel: - # build with override + from superset.databases.dao import ( # pylint: disable=import-outside-toplevel + DatabaseDAO, + ) + + if ssh_tunnel := override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( + database_id=self.id + ): + # if 
ssh_tunnel is available build engine with information url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_tunnel.bind_host = url.host ssh_tunnel.bind_port = url.port ssh_params = ssh_tunnel.parameters() - try: - with sshtunnel.open_tunnel(**ssh_params) as server: - yield self._get_sqla_engine( - schema=schema, - nullpool=nullpool, - source=source, - ssh_tunnel_server=server, - ) - except Exception as ex: - raise ex - - else: - # do look up in table for using database_id - try: + with sshtunnel.open_tunnel(**ssh_params) as server: yield self._get_sqla_engine( - schema=schema, nullpool=nullpool, source=source + schema=schema, + nullpool=nullpool, + source=source, + ssh_tunnel_server=server, ) - except Exception as ex: - raise ex + + else: + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) def _get_sqla_engine( self, @@ -594,14 +591,16 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :return: The table/schema pairs """ try: - return { - (table, schema) - for table in self.db_engine_spec.get_table_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + tables = { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=inspector, + schema=schema, + ) + } + return tables except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -628,17 +627,23 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :return: set of views """ try: - return { - (view, schema) - for view in self.db_engine_spec.get_view_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + 
@contextmanager + def get_inspector_with_context(self): + with self.get_sqla_engine_with_context() as engine: + yield sqla.inspect(engine) + @cache_util.memoized_func( key="db:{self.id}:schema_list", cache=cache_manager.cache, @@ -660,7 +665,8 @@ def get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - return self.db_engine_spec.get_schema_names(self.inspector) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -728,7 +734,8 @@ def get_table_comment( def get_columns( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.db_engine_spec.get_columns(self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_columns(inspector, table_name, schema) def get_metrics( self, @@ -740,26 +747,29 @@ def get_metrics( def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - indexes = self.inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + with self.get_inspector_with_context() as inspector: + indexes = inspector.get_indexes(table_name, schema) + return self.db_engine_spec.normalize_indexes(indexes) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None ) -> Dict[str, Any]: - pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {} + with self.get_inspector_with_context() as inspector: + pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} - def _convert(value: Any) -> Any: - try: - return utils.base_json_conv(value) - except TypeError: - return None + def _convert(value: Any) -> Any: + try: + return utils.base_json_conv(value) + except TypeError: + return None - return {key: _convert(value) for key, value in pk_constraint.items()} + return {key: 
_convert(value) for key, value in pk_constraint.items()} def get_foreign_keys( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.inspector.get_foreign_keys(table_name, schema) + with self.get_inspector_with_context() as inspector: + return inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, diff --git a/tests/conftest.py b/tests/conftest.py index a5945f2f5c4c..341c0102289e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,8 +70,9 @@ def mock_provider() -> Mock: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: - with example_db_provider().get_sqla_engine_with_context() as engine: - return engine + with create_app().app_context(): + with example_db_provider().get_sqla_engine_with_context() as engine: + return engine @fixture(scope="session") From bb7805592db1fc44daf2e4219a63f0541fe51bdf Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 30 Nov 2022 00:33:07 -0500 Subject: [PATCH 51/75] fix lint --- superset/databases/dao.py | 2 +- superset/models/core.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 4e9ad1ee8c13..d5a58245d151 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -127,7 +127,7 @@ def get_related_objects(cls, database_id: int) -> Dict[str, Any]: ) @classmethod - def get_ssh_tunnel(cls, database_id: int) -> Dict[str, Any]: + def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel: ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == database_id) diff --git a/superset/models/core.py b/superset/models/core.py index 095da48148ae..d94372577c2d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -374,10 +374,10 @@ def get_sqla_engine_with_context( source: Optional[utils.QuerySource] = None, override_ssh_tunnel: Optional["SSHTunnel"] = 
None, ) -> Engine: - ssh_params = {} - from superset.databases.dao import ( # pylint: disable=import-outside-toplevel + ssh_params: Dict[str, Any] = {} + from superset.databases.dao import ( # pylint: disable=import-outside-toplevel DatabaseDAO, - ) + ) if ssh_tunnel := override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( database_id=self.id @@ -640,7 +640,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @contextmanager - def get_inspector_with_context(self): + def get_inspector_with_context(self) -> Inspector: with self.get_sqla_engine_with_context() as engine: yield sqla.inspect(engine) From f507385113b0437cedc07734d80be2ba91a16663 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 30 Nov 2022 15:28:37 -0500 Subject: [PATCH 52/75] fix migrations --- ...20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 690d7cd5037a..9e3809ae5789 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -17,14 +17,14 @@ """create_ssh_tunnel_credentials_tbl Revision ID: f3c2d8ec8595 -Revises: deb4c9d4a4ef +Revises: 4ce1d9b25135 Create Date: 2022-10-20 10:48:08.722861 """ # revision identifiers, used by Alembic. 
revision = "f3c2d8ec8595" -down_revision = "deb4c9d4a4ef" +down_revision = "4ce1d9b25135" from uuid import uuid4 From 86436b6d010a3d30891dadf93216b631bf52134f Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 1 Dec 2022 16:43:28 -0500 Subject: [PATCH 53/75] Revert "chore(ssh-tunnel): create `contextmanager` for sql.inspect (#22251)" This reverts commit 44487396cef334d57bab8fd97100d943a90e0934. --- .../databases/commands/test_connection.py | 14 +-- superset/databases/schemas.py | 4 + superset/models/core.py | 108 ++++++++---------- tests/conftest.py | 5 +- 4 files changed, 59 insertions(+), 72 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 1fbf16efa70a..7b913ed20298 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,7 +32,6 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO -from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -91,13 +90,10 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - # Generate tunnel if present in the properties + # TODO: (hughhh) uncomment in API enablement PR ssh_tunnel = None - if ssh_tunnel := self._properties.get("ssh_tunnel"): - url = make_url_safe(database.sqlalchemy_uri_decrypted) - ssh_tunnel["bind_host"] = url.host - ssh_tunnel["bind_port"] = url.port - ssh_tunnel = SSHTunnel(**ssh_tunnel) + # if self._properties.get("ssh_tunnel"): + # ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) event_logger.log_with_context( action="test_connection_attempt", @@ -108,9 +104,7 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return 
engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context( - override_ssh_tunnel=ssh_tunnel - ) as engine: + with database.get_sqla_engine_with_context(ssh_tunnel=ssh_tunnel) as engine: try: alive = func_timeout( app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(), diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 2dcf5ac033e2..b874d147a22a 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -471,6 +471,10 @@ class DatabaseSSHTunnel(Schema): private_key = fields.String(required=False) private_key_password = fields.String(required=False) + # remote binding port + bind_host = fields.String() + bind_port = fields.Integer() + class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): diff --git a/superset/models/core.py b/superset/models/core.py index d94372577c2d..fd6ad8be9347 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,31 +372,34 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - override_ssh_tunnel: Optional["SSHTunnel"] = None, + ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: - ssh_params: Dict[str, Any] = {} - from superset.databases.dao import ( # pylint: disable=import-outside-toplevel - DatabaseDAO, - ) - - if ssh_tunnel := override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( - database_id=self.id - ): - # if ssh_tunnel is available build engine with information + ssh_params = {} + if ssh_tunnel: + # build with override url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_tunnel.bind_host = url.host ssh_tunnel.bind_port = url.port ssh_params = ssh_tunnel.parameters() - with sshtunnel.open_tunnel(**ssh_params) as server: - yield self._get_sqla_engine( - schema=schema, - nullpool=nullpool, - source=source, - ssh_tunnel_server=server, - ) + try: + with sshtunnel.open_tunnel(**ssh_params) as server: + yield self._get_sqla_engine( + 
schema=schema, + nullpool=nullpool, + source=source, + ssh_tunnel_server=server, + ) + except Exception as ex: + raise ex else: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + # do look up in table for using database_id + try: + yield self._get_sqla_engine( + schema=schema, nullpool=nullpool, source=source + ) + except Exception as ex: + raise ex def _get_sqla_engine( self, @@ -591,16 +594,14 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :return: The table/schema pairs """ try: - with self.get_inspector_with_context() as inspector: - tables = { - (table, schema) - for table in self.db_engine_spec.get_table_names( - database=self, - inspector=inspector, - schema=schema, - ) - } - return tables + return { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -627,23 +628,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :return: set of views """ try: - with self.get_inspector_with_context() as inspector: - return { - (view, schema) - for view in self.db_engine_spec.get_view_names( - database=self, - inspector=inspector, - schema=schema, - ) - } + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) - @contextmanager - def get_inspector_with_context(self) -> Inspector: - with self.get_sqla_engine_with_context() as engine: - yield sqla.inspect(engine) - @cache_util.memoized_func( key="db:{self.id}:schema_list", cache=cache_manager.cache, @@ -665,8 +660,7 @@ def get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - with self.get_inspector_with_context() as inspector: - return 
self.db_engine_spec.get_schema_names(inspector) + return self.db_engine_spec.get_schema_names(self.inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -734,8 +728,7 @@ def get_table_comment( def get_columns( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_columns(inspector, table_name, schema) + return self.db_engine_spec.get_columns(self.inspector, table_name, schema) def get_metrics( self, @@ -747,29 +740,26 @@ def get_metrics( def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - indexes = inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + indexes = self.inspector.get_indexes(table_name, schema) + return self.db_engine_spec.normalize_indexes(indexes) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None ) -> Dict[str, Any]: - with self.get_inspector_with_context() as inspector: - pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} + pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {} - def _convert(value: Any) -> Any: - try: - return utils.base_json_conv(value) - except TypeError: - return None + def _convert(value: Any) -> Any: + try: + return utils.base_json_conv(value) + except TypeError: + return None - return {key: _convert(value) for key, value in pk_constraint.items()} + return {key: _convert(value) for key, value in pk_constraint.items()} def get_foreign_keys( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return inspector.get_foreign_keys(table_name, schema) + return self.inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, 
diff --git a/tests/conftest.py b/tests/conftest.py index 341c0102289e..a5945f2f5c4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,9 +70,8 @@ def mock_provider() -> Mock: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: - with create_app().app_context(): - with example_db_provider().get_sqla_engine_with_context() as engine: - return engine + with example_db_provider().get_sqla_engine_with_context() as engine: + return engine @fixture(scope="session") From 54a8d7f8a6ed9cc559bb919fe08f536913f969fb Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 1 Dec 2022 17:24:09 -0500 Subject: [PATCH 54/75] debugging --- .../databases/commands/test_connection.py | 14 ++- superset/databases/schemas.py | 4 - superset/models/core.py | 106 ++++++++++-------- tests/conftest.py | 8 +- 4 files changed, 74 insertions(+), 58 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 7b913ed20298..1fbf16efa70a 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,6 +32,7 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -90,10 +91,13 @@ def run(self) -> None: # pylint: disable=too-many-statements database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - # TODO: (hughhh) uncomment in API enablement PR + # Generate tunnel if present in the properties ssh_tunnel = None - # if self._properties.get("ssh_tunnel"): - # ssh_tunnel = SSHTunnel(**self._properties["ssh_tunnel"]) + if ssh_tunnel := self._properties.get("ssh_tunnel"): + url = make_url_safe(database.sqlalchemy_uri_decrypted) + ssh_tunnel["bind_host"] = 
url.host + ssh_tunnel["bind_port"] = url.port + ssh_tunnel = SSHTunnel(**ssh_tunnel) event_logger.log_with_context( action="test_connection_attempt", @@ -104,7 +108,9 @@ def ping(engine: Engine) -> bool: with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context(ssh_tunnel=ssh_tunnel) as engine: + with database.get_sqla_engine_with_context( + override_ssh_tunnel=ssh_tunnel + ) as engine: try: alive = func_timeout( app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(), diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index b874d147a22a..2dcf5ac033e2 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -471,10 +471,6 @@ class DatabaseSSHTunnel(Schema): private_key = fields.String(required=False) private_key_password = fields.String(required=False) - # remote binding port - bind_host = fields.String() - bind_port = fields.Integer() - class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): diff --git a/superset/models/core.py b/superset/models/core.py index fd6ad8be9347..961b5265c06a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -372,34 +372,31 @@ def get_sqla_engine_with_context( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel: Optional["SSHTunnel"] = None, + override_ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: ssh_params = {} - if ssh_tunnel: - # build with override + from superset.databases.dao import ( # pylint: disable=import-outside-toplevel + DatabaseDAO, + ) + + if ssh_tunnel := override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( + database_id=self.id + ): + # if ssh_tunnel is available build engine with information url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_tunnel.bind_host = url.host ssh_tunnel.bind_port = url.port ssh_params = ssh_tunnel.parameters() - try: - with sshtunnel.open_tunnel(**ssh_params) as 
server: - yield self._get_sqla_engine( - schema=schema, - nullpool=nullpool, - source=source, - ssh_tunnel_server=server, - ) - except Exception as ex: - raise ex - - else: - # do look up in table for using database_id - try: + with sshtunnel.open_tunnel(**ssh_params) as server: yield self._get_sqla_engine( - schema=schema, nullpool=nullpool, source=source + schema=schema, + nullpool=nullpool, + source=source, + ssh_tunnel_server=server, ) - except Exception as ex: - raise ex + + else: + yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) def _get_sqla_engine( self, @@ -594,14 +591,16 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :return: The table/schema pairs """ try: - return { - (table, schema) - for table in self.db_engine_spec.get_table_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + tables = { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=inspector, + schema=schema, + ) + } + return tables except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -628,17 +627,23 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :return: set of views """ try: - return { - (view, schema) - for view in self.db_engine_spec.get_view_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + @contextmanager + def get_inspector_with_context(self) -> Inspector: + with self.get_sqla_engine_with_context() as engine: + yield sqla.inspect(engine) + @cache_util.memoized_func( key="db:{self.id}:schema_list", cache=cache_manager.cache, @@ -660,7 +665,8 @@ def 
get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - return self.db_engine_spec.get_schema_names(self.inspector) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -728,7 +734,8 @@ def get_table_comment( def get_columns( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.db_engine_spec.get_columns(self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_columns(inspector, table_name, schema) def get_metrics( self, @@ -740,26 +747,29 @@ def get_metrics( def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - indexes = self.inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + with self.get_inspector_with_context() as inspector: + indexes = inspector.get_indexes(table_name, schema) + return self.db_engine_spec.normalize_indexes(indexes) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None ) -> Dict[str, Any]: - pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {} + with self.get_inspector_with_context() as inspector: + pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} - def _convert(value: Any) -> Any: - try: - return utils.base_json_conv(value) - except TypeError: - return None + def _convert(value: Any) -> Any: + try: + return utils.base_json_conv(value) + except TypeError: + return None - return {key: _convert(value) for key, value in pk_constraint.items()} + return {key: _convert(value) for key, value in pk_constraint.items()} def get_foreign_keys( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.inspector.get_foreign_keys(table_name, schema) + with 
self.get_inspector_with_context() as inspector: + return inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, diff --git a/tests/conftest.py b/tests/conftest.py index a5945f2f5c4c..542b74093b68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,7 @@ from unittest.mock import MagicMock, Mock, PropertyMock from flask import Flask +from flask import current_app from flask.ctx import AppContext from pytest import fixture @@ -41,6 +42,8 @@ TableToDfConvertorImpl, ) +from tests.integration_tests.test_app import app + SUPPORT_DATETIME_TYPE = "support_datetime_type" if TYPE_CHECKING: @@ -70,8 +73,9 @@ def mock_provider() -> Mock: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: - with example_db_provider().get_sqla_engine_with_context() as engine: - return engine + with app.app_context(): + with example_db_provider().get_sqla_engine_with_context() as engine: + return engine @fixture(scope="session") From 3f6afec95a8bf10a151a23218fb03d67c49d49eb Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:02:45 -0300 Subject: [PATCH 55/75] fix(ssh_tunnel): Address Base PR comments from peer review (#22306) --- setup.py | 2 +- superset/constants.py | 2 ++ .../databases/ssh_tunnel/commands/__init__.py | 16 +++++++++++++ .../databases/ssh_tunnel/commands/create.py | 4 +--- .../databases/ssh_tunnel/commands/update.py | 5 +--- superset/databases/ssh_tunnel/models.py | 23 +++++++++++-------- ...c8595_create_ssh_tunnel_credentials_tbl.py | 16 ++++++------- superset/models/core.py | 5 ++-- 8 files changed, 45 insertions(+), 28 deletions(-) create mode 100644 superset/databases/ssh_tunnel/commands/__init__.py diff --git a/setup.py b/setup.py index 67828bc13974..0fbaf33a8ab0 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,7 @@ def get_git_sha() -> str: 
"PyJWT>=2.4.0, <3.0", "redis", "selenium>=3.141.0", - "sshtunnel>=0.4.0", + "sshtunnel>=0.4.0, <0.5", "simplejson>=3.15.0", "slack_sdk>=3.1.1, <4", "sqlalchemy>=1.4, <2", diff --git a/superset/constants.py b/superset/constants.py index 7d759acf6741..c0fbb7c2cd8d 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -34,6 +34,8 @@ NO_TIME_RANGE = "No filter" +SSH_TUNNELLING_LOCAL_BIND_ADDRESS = "127.0.0.1" + class RouteMethod: # pylint: disable=too-few-public-methods """ diff --git a/superset/databases/ssh_tunnel/commands/__init__.py b/superset/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index 1185d20bdf7c..29ee6c12471b 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict from flask_appbuilder.models.sqla import Model -from marshmallow import ValidationError from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOCreateFailedError -from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelCreateFailedError from superset.databases.ssh_tunnel.dao import SSHTunnelDAO diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index a9d9a6fe6d18..fd73c7b3ddf0 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -15,16 +15,13 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from flask_appbuilder.models.sqla import Model -from marshmallow import ValidationError from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOUpdateFailedError from superset.databases.ssh_tunnel.commands.exceptions import ( - SSHTunnelDeleteFailedError, - SSHTunnelInvalidError, SSHTunnelNotFoundError, SSHTunnelUpdateFailedError, ) diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index f7c8c9618106..f3bcd303d9fb 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -18,11 +18,12 @@ from typing import Any, Dict import sqlalchemy as sa +from flask import current_app from flask_appbuilder import Model from sqlalchemy.orm import backref, relationship from sqlalchemy_utils import EncryptedType -from superset import app +from superset.constants import SSH_TUNNELLING_LOCAL_BIND_ADDRESS from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, @@ -30,7 +31,7 @@ ImportExportMixin, ) 
-app_config = app.config +app_config = current_app.config class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): @@ -41,15 +42,17 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): __tablename__ = "ssh_tunnels" id = sa.Column(sa.Integer, primary_key=True) - database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database_id = sa.Column( + sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True + ) database: Database = relationship( "Database", backref=backref("ssh_tunnels", cascade="all, delete-orphan"), foreign_keys=[database_id], ) - server_address = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) - server_port = sa.Column(EncryptedType(sa.Integer, app_config["SECRET_KEY"])) + server_address = sa.Column(sa.Text) + server_port = sa.Column(sa.Integer) username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) # basic authentication @@ -65,8 +68,8 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) - bind_host = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) - bind_port = sa.Column(EncryptedType(sa.Integer, app_config["SECRET_KEY"])) + bind_host = sa.Column(sa.Text) + bind_port = sa.Column(sa.Integer) def parameters(self) -> Dict[str, Any]: params = { @@ -74,13 +77,13 @@ def parameters(self) -> Dict[str, Any]: "ssh_port": self.server_port, "ssh_username": self.username, "remote_bind_address": (self.bind_host, self.bind_port), - "local_bind_address": ("127.0.0.1",), + "local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,), } if self.password: params["ssh_password"] = self.password elif self.private_key: - params["ssh_pkey"] = self.private_key - params["ssh_private_key_password"] = self.private_key_password + params["private_key"] = self.private_key + params["private_key_password"] = self.private_key_password return params diff --git 
a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 9e3809ae5789..b90ccae50f64 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -52,12 +52,12 @@ def upgrade(): sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), # SSHTunnelCredentials sa.Column("id", sa.Integer(), primary_key=True), - sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id")), - sa.Column("server_address", encrypted_field_factory.create(sa.String(1024))), - sa.Column("server_port", encrypted_field_factory.create(sa.INTEGER())), - sa.Column("username", encrypted_field_factory.create(sa.String(1024))), + sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), unique=True), + sa.Column("server_address", sa.String(256)), + sa.Column("server_port", sa.INTEGER()), + sa.Column("username", encrypted_field_factory.create(sa.String(256))), sa.Column( - "password", encrypted_field_factory.create(sa.String(1024)), nullable=True + "password", encrypted_field_factory.create(sa.String(256)), nullable=True ), sa.Column( "private_key", @@ -66,11 +66,11 @@ def upgrade(): ), sa.Column( "private_key_password", - encrypted_field_factory.create(sa.String(1024)), + encrypted_field_factory.create(sa.String(256)), nullable=True, ), - sa.Column("bind_host", encrypted_field_factory.create(sa.String(1024))), - sa.Column("bind_port", encrypted_field_factory.create(sa.INTEGER())), + sa.Column("bind_host", sa.String(256)), + sa.Column("bind_port", sa.INTEGER()), ) diff --git a/superset/models/core.py b/superset/models/core.py index 961b5265c06a..309c4444a652 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -55,7 +55,7 @@ from sqlalchemy.sql import expression, 
Select from superset import app, db_engine_specs -from superset.constants import PASSWORD_MASK +from superset.constants import PASSWORD_MASK, SSH_TUNNELLING_LOCAL_BIND_ADDRESS from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, security_manager @@ -451,7 +451,8 @@ def _get_sqla_engine( # update sqlalchemy_url url = make_url_safe(sqlalchemy_url) sqlalchemy_url = url.set( - host="127.0.0.1", port=ssh_tunnel_server.local_bind_port + host=SSH_TUNNELLING_LOCAL_BIND_ADDRESS, + port=ssh_tunnel_server.local_bind_port, ) try: From 7625566415b26096f481ed3638ceecbba52fc936 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 5 Dec 2022 13:47:52 -0500 Subject: [PATCH 56/75] fix pre-commit --- tests/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 542b74093b68..9d13e581704e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,8 +28,7 @@ from typing import Callable, TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock -from flask import Flask -from flask import current_app +from flask import current_app, Flask from flask.ctx import AppContext from pytest import fixture @@ -41,7 +40,6 @@ from tests.example_data.data_loading.pandas.table_df_convertor import ( TableToDfConvertorImpl, ) - from tests.integration_tests.test_app import app SUPPORT_DATETIME_TYPE = "support_datetime_type" From 0578a8e54db0d7620e9c329cb64d9b46477c4d3d Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 6 Dec 2022 14:14:03 -0500 Subject: [PATCH 57/75] working changes --- superset/databases/commands/create.py | 11 +++++++ .../databases/commands/test_connection.py | 1 - superset/databases/dao.py | 2 +- superset/databases/schemas.py | 33 ++++++++++--------- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/superset/databases/commands/create.py 
b/superset/databases/commands/create.py index 4dc8e8eda49e..ef2f555d61fc 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -31,6 +31,7 @@ ) from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO from superset.exceptions import SupersetErrorsException from superset.extensions import db, event_logger, security_manager @@ -77,6 +78,16 @@ def run(self) -> Model: security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) + + if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + SSHTunnelDAO.create( + { + **ssh_tunnel_properties, + "database_id": database.id, + }, + commit=False, + ) + db.session.commit() except DAOCreateFailedError as ex: db.session.rollback() diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 1fbf16efa70a..ac1d12b78caf 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -92,7 +92,6 @@ def run(self) -> None: # pylint: disable=too-many-statements database.db_engine_spec.mutate_db_for_connection_test(database) # Generate tunnel if present in the properties - ssh_tunnel = None if ssh_tunnel := self._properties.get("ssh_tunnel"): url = make_url_safe(database.sqlalchemy_uri_decrypted) ssh_tunnel["bind_host"] = url.host diff --git a/superset/databases/dao.py b/superset/databases/dao.py index d5a58245d151..c82f0db5745a 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -127,7 +127,7 @@ def get_related_objects(cls, database_id: int) -> Dict[str, Any]: ) @classmethod - def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel: + def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]: ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == 
database_id) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 2dcf5ac033e2..6dd7d0d79fcf 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -365,6 +365,22 @@ class Meta: # pylint: disable=too-few-public-methods ) +class DatabaseSSHTunnel(Schema): + id = fields.Integer() + database_id = fields.Integer() + + server_address = fields.String() + server_port = fields.Integer() + username = fields.String() + + # Basic Authentication + password = fields.String(required=False) + + # password protected private key authentication + private_key = fields.String(required=False) + private_key_password = fields.String(required=False) + + class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE @@ -409,6 +425,7 @@ class Meta: # pylint: disable=too-few-public-methods is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) uuid = fields.String(required=False) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin): @@ -456,22 +473,6 @@ class Meta: # pylint: disable=too-few-public-methods external_url = fields.String(allow_none=True) -class DatabaseSSHTunnel(Schema): - id = fields.Integer() - database_id = fields.Integer() - - server_address = fields.String() - server_port = fields.Integer() - username = fields.String() - - # Basic Authentication - password = fields.String(required=False) - - # password protected private key authentication - private_key = fields.String(required=False) - private_key_password = fields.String(required=False) - - class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): rename_encrypted_extra = pre_load(rename_encrypted_extra) From ec20429a84feba863d762e4160c14d516951d40b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 6 Dec 2022 15:10:37 -0500 Subject: [PATCH 58/75] 
refactor bind_host and bind_port --- .../databases/commands/test_connection.py | 3 --- superset/databases/ssh_tunnel/models.py | 7 ++---- ...c8595_create_ssh_tunnel_credentials_tbl.py | 2 -- superset/models/core.py | 25 +++++++++---------- 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index ac1d12b78caf..8027efcb49a4 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -93,9 +93,6 @@ def run(self) -> None: # pylint: disable=too-many-statements # Generate tunnel if present in the properties if ssh_tunnel := self._properties.get("ssh_tunnel"): - url = make_url_safe(database.sqlalchemy_uri_decrypted) - ssh_tunnel["bind_host"] = url.host - ssh_tunnel["bind_port"] = url.port ssh_tunnel = SSHTunnel(**ssh_tunnel) event_logger.log_with_context( diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index f3bcd303d9fb..d4ca3504cdf8 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -68,15 +68,12 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) - bind_host = sa.Column(sa.Text) - bind_port = sa.Column(sa.Integer) - - def parameters(self) -> Dict[str, Any]: + def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: params = { "ssh_address_or_host": self.server_address, "ssh_port": self.server_port, "ssh_username": self.username, - "remote_bind_address": (self.bind_host, self.bind_port), + "remote_bind_address": (bind_host, bind_port), "local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,), } diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py 
index b90ccae50f64..75bad1e53ed5 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -69,8 +69,6 @@ def upgrade(): encrypted_field_factory.create(sa.String(256)), nullable=True, ), - sa.Column("bind_host", sa.String(256)), - sa.Column("bind_port", sa.INTEGER()), ) diff --git a/superset/models/core.py b/superset/models/core.py index 309c4444a652..77d9170c427e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -21,7 +21,7 @@ import logging import textwrap from ast import literal_eval -from contextlib import closing, contextmanager +from contextlib import closing, contextmanager, nullcontext from copy import deepcopy from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING @@ -384,19 +384,18 @@ def get_sqla_engine_with_context( ): # if ssh_tunnel is available build engine with information url = make_url_safe(self.sqlalchemy_uri_decrypted) - ssh_tunnel.bind_host = url.host - ssh_tunnel.bind_port = url.port - ssh_params = ssh_tunnel.parameters() - with sshtunnel.open_tunnel(**ssh_params) as server: - yield self._get_sqla_engine( - schema=schema, - nullpool=nullpool, - source=source, - ssh_tunnel_server=server, - ) - + ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port) + engine_context = sshtunnel.open_tunnel(**ssh_params) else: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + engine_context = nullcontext() + + with engine_context as server_context: + yield self._get_sqla_engine( + schema=schema, + nullpool=nullpool, + source=source, + ssh_tunnel_server=server_context, + ) def _get_sqla_engine( self, From 1f57d4afcbe877fca2637c97b345c6bc8d9f63c2 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 7 Dec 2022 13:23:57 -0500 Subject: [PATCH 59/75] refactor create flow for temp ssh 
tunnels --- superset/databases/commands/create.py | 17 +++++++++-------- superset/models/core.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index ef2f555d61fc..8b3a5f55bafc 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -72,15 +72,9 @@ def run(self) -> Model: database = DatabaseDAO.create(self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - # adding a new database we always want to force refresh schema list - schemas = database.get_all_schema_names(cache=False) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) - + ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): - SSHTunnelDAO.create( + ssh_tunnel = SSHTunnelDAO.create( { **ssh_tunnel_properties, "database_id": database.id, @@ -88,6 +82,13 @@ def run(self) -> Model: commit=False, ) + # adding a new database we always want to force refresh schema list + schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", security_manager.get_schema_perm(database, schema) + ) + db.session.commit() except DAOCreateFailedError as ex: db.session.rollback() diff --git a/superset/models/core.py b/superset/models/core.py index 77d9170c427e..2826d4e3a738 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -383,6 +383,7 @@ def get_sqla_engine_with_context( database_id=self.id ): # if ssh_tunnel is available build engine with information + logger.info("Creating ssh tunnel for db: %", self.id) url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port) engine_context = sshtunnel.open_tunnel(**ssh_params) @@ -640,8 +641,12 @@ 
def get_all_view_names_in_schema( # pylint: disable=unused-argument raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @contextmanager - def get_inspector_with_context(self) -> Inspector: - with self.get_sqla_engine_with_context() as engine: + def get_inspector_with_context( + self, ssh_tunnel: Optional["SSHTunnel"] = None + ) -> Inspector: + with self.get_sqla_engine_with_context( + override_ssh_tunnel=ssh_tunnel + ) as engine: yield sqla.inspect(engine) @cache_util.memoized_func( @@ -653,6 +658,7 @@ def get_all_schema_names( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, + ssh_tunnel: Optional["SSHTunnel"] = None, ) -> List[str]: """Parameters need to be passed as keyword arguments. @@ -665,7 +671,7 @@ def get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - with self.get_inspector_with_context() as inspector: + with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector: return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex From ed19a3eae3ac541550f9543ec50adc62fcf12714 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 8 Dec 2022 10:38:14 -0500 Subject: [PATCH 60/75] remove logger --- superset/models/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/models/core.py b/superset/models/core.py index 2826d4e3a738..0694d97fac81 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -383,7 +383,6 @@ def get_sqla_engine_with_context( database_id=self.id ): # if ssh_tunnel is available build engine with information - logger.info("Creating ssh tunnel for db: %", self.id) url = make_url_safe(self.sqlalchemy_uri_decrypted) ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port) engine_context = sshtunnel.open_tunnel(**ssh_params) From 852c8bbca536d2cefc6f19f2088173a5514af847 Mon Sep 17 00:00:00 2001 From: Antonio Rivero 
Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Thu, 8 Dec 2022 18:50:22 -0300 Subject: [PATCH 61/75] chore(ssh_tunnel): Add extra tests to SSHTunnel commands (#22372) --- .../databases/ssh_tunnel/commands/create.py | 40 +++++++++- .../ssh_tunnel/commands/exceptions.py | 9 +++ .../databases/ssh_tunnel/commands/update.py | 10 +++ .../databases/ssh_tunnel/__init__.py | 16 ++++ .../databases/ssh_tunnel/commands/__init__.py | 16 ++++ .../ssh_tunnel/commands/commands_tests.py | 76 +++++++++++++++++++ .../ssh_tunnel/commands/create_test.py | 26 +++++++ .../ssh_tunnel/commands/update_test.py | 24 ++++++ 8 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 tests/integration_tests/databases/ssh_tunnel/__init__.py create mode 100644 tests/integration_tests/databases/ssh_tunnel/commands/__init__.py create mode 100644 tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index 29ee6c12471b..b2e62f340b0b 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -15,14 +15,20 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any, Dict +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model +from marshmallow import ValidationError from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOCreateFailedError -from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelCreateFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, + SSHTunnelRequiredFieldValidationError, +) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.extensions import event_logger logger = logging.getLogger(__name__) @@ -45,4 +51,32 @@ def run(self) -> Model: def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER - return + exceptions: List[ValidationError] = [] + database_id: Optional[int] = self._properties.get("database_id") + server_address: Optional[str] = self._properties.get("server_address") + server_port: Optional[int] = self._properties.get("server_port") + username: Optional[str] = self._properties.get("username") + private_key: Optional[str] = self._properties.get("private_key") + private_key_password: Optional[str] = self._properties.get( + "private_key_password" + ) + if not database_id: + exceptions.append(SSHTunnelRequiredFieldValidationError("database_id")) + if not server_address: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) + if not server_port: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) + if not username: + exceptions.append(SSHTunnelRequiredFieldValidationError("username")) + if private_key_password and private_key is None: + exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) + if exceptions: + exception = SSHTunnelInvalidError() + exception.add_list(exceptions) + event_logger.log_with_context( + 
action="ssh_tunnel_creation_failed.{}.{}".format( + exception.__class__.__name__, + ".".join(exception.get_list_classnames()), + ) + ) + raise exception diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py index 9e3bce81a64f..db2d3173de01 100644 --- a/superset/databases/ssh_tunnel/commands/exceptions.py +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from flask_babel import lazy_gettext as _ +from marshmallow import ValidationError from superset.commands.exceptions import ( CommandException, @@ -43,3 +44,11 @@ class SSHTunnelUpdateFailedError(UpdateFailedError): class SSHTunnelCreateFailedError(CommandException): message = _("Creating SSH Tunnel failed for an unknown reason") + + +class SSHTunnelRequiredFieldValidationError(ValidationError): + def __init__(self, field_name: str) -> None: + super().__init__( + [_("Field is required")], + field_name=field_name, + ) diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index fd73c7b3ddf0..8d2feaf1b0d5 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -22,7 +22,9 @@ from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOUpdateFailedError from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelInvalidError, SSHTunnelNotFoundError, + SSHTunnelRequiredFieldValidationError, SSHTunnelUpdateFailedError, ) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO @@ -50,3 +52,11 @@ def validate(self) -> None: self._model = SSHTunnelDAO.find_by_id(self._model_id) if not self._model: raise SSHTunnelNotFoundError() + private_key: Optional[str] = self._properties.get("private_key") + private_key_password: Optional[str] = self._properties.get( + "private_key_password" + ) + if 
private_key_password and private_key is None: + exception = SSHTunnelInvalidError() + exception.add(SSHTunnelRequiredFieldValidationError("private_key")) + raise exception diff --git a/tests/integration_tests/databases/ssh_tunnel/__init__.py b/tests/integration_tests/databases/ssh_tunnel/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py b/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py new file mode 100644 index 000000000000..75e5a55e862c --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from unittest import mock, skip +from unittest.mock import patch + +import pytest + +from superset import security_manager +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelInvalidError, + SSHTunnelNotFoundError, +) +from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand +from tests.integration_tests.base_tests import SupersetTestCase + + +class TestCreateSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_create_invalid_database_id(self, mock_g): + mock_g.user = security_manager.find_user("admin") + command = CreateSSHTunnelCommand( + None, + { + "server_address": "127.0.0.1", + "server_port": 5432, + "username": "test_user", + }, + ) + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") + + +class TestUpdateSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_update_ssh_tunnel_not_found(self, mock_g): + mock_g.user = security_manager.find_user("admin") + # We have not created a SSH Tunnel yet so id = 1 is invalid + command = UpdateSSHTunnelCommand( + 1, + { + "server_address": "127.0.0.1", + "server_port": 5432, + "username": "test_user", + }, + ) + with pytest.raises(SSHTunnelNotFoundError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel not found.") + + +class TestDeleteSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_delete_ssh_tunnel_not_found(self, mock_g): + mock_g.user = security_manager.find_user("admin") + # We have not created a SSH Tunnel yet so id = 1 is invalid + command = DeleteSSHTunnelCommand(1) + with pytest.raises(SSHTunnelNotFoundError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel not found.") 
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 2b6b8d9aaebb..2a5738ebd396 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -20,6 +20,8 @@ import pytest from sqlalchemy.orm.session import Session +from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError + def test_create_ssh_tunnel_command() -> None: from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand @@ -40,3 +42,27 @@ def test_create_ssh_tunnel_command() -> None: assert result is not None assert isinstance(result, SSHTunnel) + + +def test_create_ssh_tunnel_command_invalid_params() -> None: + from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + # If we are trying to create a tunnel with a private_key_password + # then a private_key is mandatory + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "private_key_password": "bar", + } + + command = CreateSSHTunnelCommand(db.id, properties) + + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 4b8a1fd2a002..58f90054ccd1 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -20,6 +20,8 @@ import pytest from sqlalchemy.orm.session import Session +from superset.databases.ssh_tunnel.commands.exceptions import 
SSHTunnelInvalidError + @pytest.fixture def session_with_data(session: Session) -> Iterator[Session]: @@ -67,3 +69,25 @@ def test_update_shh_tunnel_command(session_with_data: Session) -> None: assert result assert isinstance(result, SSHTunnel) assert "Test2" == result.server_address + + +def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == result.server_address + + # If we are trying to update a tunnel with a private_key_password + # then a private_key is mandatory + update_payload = {"private_key_password": "pass"} + command = UpdateSSHTunnelCommand(1, update_payload) + + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") From be5c0051fd697af17edf0ad43a807cb76f1fecbc Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 8 Dec 2022 17:06:20 -0500 Subject: [PATCH 62/75] add flush to allow database.id to be populated --- superset/databases/commands/create.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 8b3a5f55bafc..3b5874aabe02 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -71,6 +71,7 @@ def run(self) -> Model: try: database = DatabaseDAO.create(self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) + db.session.flush() ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): From c636ce7705716af950def37e1d8c9b46e96bc85c Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 9 Dec 2022 11:51:01 -0500 Subject: [PATCH 63/75] make 
sure to use inspector with context --- superset/models/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 0694d97fac81..790dd9fe9ced 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -734,7 +734,8 @@ def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: def get_table_comment( self, table_name: str, schema: Optional[str] = None ) -> Optional[str]: - return self.db_engine_spec.get_table_comment(self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_table_comment(inspector, table_name, schema) def get_columns( self, table_name: str, schema: Optional[str] = None @@ -747,7 +748,8 @@ def get_metrics( table_name: str, schema: Optional[str] = None, ) -> List[MetricType]: - return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) def get_indexes( self, table_name: str, schema: Optional[str] = None From 908896ffa95eac9669c57a608ef57dd3118607a2 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Sun, 11 Dec 2022 22:18:11 -0500 Subject: [PATCH 64/75] remove id and database_id --- superset/databases/schemas.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 6dd7d0d79fcf..5e9e3a552d29 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -366,9 +366,6 @@ class Meta: # pylint: disable=too-few-public-methods class DatabaseSSHTunnel(Schema): - id = fields.Integer() - database_id = fields.Integer() - server_address = fields.String() server_port = fields.Integer() username = fields.String() From e3ef835c3f681c86243f9c9c64417c2c75384b8f Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 12 Dec 2022 11:12:09 -0500 Subject: [PATCH 65/75] uselist --- 
superset/databases/ssh_tunnel/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index d4ca3504cdf8..17f4628f8070 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -47,7 +47,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): ) database: Database = relationship( "Database", - backref=backref("ssh_tunnels", cascade="all, delete-orphan"), + backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"), foreign_keys=[database_id], ) From c5c50ed1566b367cab8270ff0866bb5393c56228 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Thu, 15 Dec 2022 14:08:07 -0500 Subject: [PATCH 66/75] feat(ssh-tunnel): ssh manager config + feature flag (#22201) --- superset/config.py | 38 ++++++++++++++++++++++++- superset/constants.py | 2 -- superset/databases/ssh_tunnel/models.py | 4 +-- superset/models/core.py | 12 +++----- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/superset/config.py b/superset/config.py index f163997c6ee4..71b220333977 100644 --- a/superset/config.py +++ b/superset/config.py @@ -45,6 +45,7 @@ ) import pkg_resources +import sshtunnel from cachelib.base import BaseCache from celery.schedules import crontab from dateutil import tz @@ -56,6 +57,7 @@ from superset.advanced_data_type.plugins.internet_port import internet_port from superset.advanced_data_type.types import AdvancedDataType from superset.constants import CHANGE_ME_SECRET_KEY +from superset.databases.utils import make_url_safe from superset.jinja_context import BaseTemplateProcessor from superset.reports.types import ReportScheduleExecutor from superset.stats_logger import DummyStatsLogger @@ -471,8 +473,42 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: "DRILL_TO_DETAIL": False, "DATAPANEL_CLOSED_BY_DEFAULT": False, "HORIZONTAL_FILTER_BAR": False, + # Allow 
users to enable ssh tunneling when creating a DB. + # Users must check whether the DB engine supports SSH Tunnels + # otherwise enabling this flag won't have any effect on the DB. + "SSH_TUNNELING": False, } +# ------------------------------ +# SSH Tunnel +# ------------------------------ +# Allow users to set the host used when connecting to the SSH Tunnel +# as localhost and any other alias (0.0.0.0) +# ---------------------------------------------------------------------- +# | +# -------------+ | +----------+ +# LOCAL | | | REMOTE | :22 SSH +# CLIENT | <== SSH ========> | SERVER | :8080 web service +# -------------+ | +----------+ +# | +# FIREWALL (only port 22 is open) + +# ---------------------------------------------------------------------- +class SSHManager: # pylint: disable=too-few-public-methods + local_bind_address = "127.0.0.1" + + @classmethod + def mutator(cls, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder) -> str: + # override any ssh tunnel configuration object + url = make_url_safe(sqlalchemy_url) + return url.set( + host=cls.local_bind_address, + port=server.local_bind_port, + ) + + +SSH_TUNNEL_MANAGER = SSHManager # pylint: disable=invalid-name + # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. 
DEFAULT_FEATURE_FLAGS.update( { @@ -1462,7 +1498,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument try: # pylint: disable=import-error,wildcard-import,unused-wildcard-import import superset_config - from superset_config import * # type:ignore + from superset_config import * # type: ignore print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]") except Exception: diff --git a/superset/constants.py b/superset/constants.py index c0fbb7c2cd8d..7d759acf6741 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -34,8 +34,6 @@ NO_TIME_RANGE = "No filter" -SSH_TUNNELLING_LOCAL_BIND_ADDRESS = "127.0.0.1" - class RouteMethod: # pylint: disable=too-few-public-methods """ diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 17f4628f8070..5f334d8c154b 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -23,7 +23,6 @@ from sqlalchemy.orm import backref, relationship from sqlalchemy_utils import EncryptedType -from superset.constants import SSH_TUNNELLING_LOCAL_BIND_ADDRESS from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, @@ -32,6 +31,7 @@ ) app_config = current_app.config +ssh_manager = app_config["SSH_TUNNEL_MANAGER"] class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): @@ -74,7 +74,7 @@ def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: "ssh_port": self.server_port, "ssh_username": self.username, "remote_bind_address": (bind_host, bind_port), - "local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,), + "local_bind_address": (ssh_manager.local_bind_address,), } if self.password: diff --git a/superset/models/core.py b/superset/models/core.py index 790dd9fe9ced..51bd0ce21b6f 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -55,7 +55,7 @@ from sqlalchemy.sql import expression, Select from superset import app, 
db_engine_specs -from superset.constants import PASSWORD_MASK, SSH_TUNNELLING_LOCAL_BIND_ADDRESS +from superset.constants import PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, security_manager @@ -66,6 +66,7 @@ from superset.utils.memoized import memoized config = app.config +ssh_manager = config["SSH_TUNNEL_MANAGER"] custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] stats_logger = config["STATS_LOGGER"] log_query = config["QUERY_LOGGER"] @@ -447,13 +448,8 @@ def _get_sqla_engine( ) if ssh_tunnel_server: - # update sqlalchemy_url - url = make_url_safe(sqlalchemy_url) - sqlalchemy_url = url.set( - host=SSH_TUNNELLING_LOCAL_BIND_ADDRESS, - port=ssh_tunnel_server.local_bind_port, - ) - + # update sqlalchemy_url with ssh tunnel manager info + sqlalchemy_url = ssh_manager.mutator(sqlalchemy_url, ssh_tunnel_server) try: return create_engine(sqlalchemy_url, **params) except Exception as ex: From 06e115b1108df99730729270405b2c69e23e550b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 15 Dec 2022 14:22:29 -0500 Subject: [PATCH 67/75] update kwarg function name --- superset/databases/ssh_tunnel/models.py | 2 +- superset/models/core.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 5f334d8c154b..0ec9c60bd0c8 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -68,7 +68,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) - def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: + def kwarg_parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: params = { "ssh_address_or_host": self.server_address, "ssh_port": 
self.server_port, diff --git a/superset/models/core.py b/superset/models/core.py index 51bd0ce21b6f..ee0bc2c4659a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -385,7 +385,9 @@ def get_sqla_engine_with_context( ): # if ssh_tunnel is available build engine with information url = make_url_safe(self.sqlalchemy_uri_decrypted) - ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port) + ssh_params = ssh_tunnel.kwarg_parameters( + bind_host=url.host, bind_port=url.port + ) engine_context = sshtunnel.open_tunnel(**ssh_params) else: engine_context = nullcontext() From 13ed50d60e0acfaa63eb60d7277b62dd94d5b00c Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Fri, 16 Dec 2022 10:31:47 -0500 Subject: [PATCH 68/75] chore(ssh-tunnel): Move SSHManager to extensions pattern (#22433) --- superset/config.py | 18 +---- superset/databases/ssh_tunnel/models.py | 20 ------ superset/extensions/__init__.py | 2 + superset/extensions/ssh.py | 88 +++++++++++++++++++++++++ superset/initialization/__init__.py | 5 ++ superset/models/core.py | 43 ++++++------ 6 files changed, 121 insertions(+), 55 deletions(-) create mode 100644 superset/extensions/ssh.py diff --git a/superset/config.py b/superset/config.py index 71b220333977..4ace263b152b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -45,7 +45,6 @@ ) import pkg_resources -import sshtunnel from cachelib.base import BaseCache from celery.schedules import crontab from dateutil import tz @@ -57,7 +56,6 @@ from superset.advanced_data_type.plugins.internet_port import internet_port from superset.advanced_data_type.types import AdvancedDataType from superset.constants import CHANGE_ME_SECRET_KEY -from superset.databases.utils import make_url_safe from superset.jinja_context import BaseTemplateProcessor from superset.reports.types import ReportScheduleExecutor from superset.stats_logger import DummyStatsLogger @@ -494,20 +492,8 @@ def _try_json_readsha(filepath: str, length: int) -> 
Optional[str]: # FIREWALL (only port 22 is open) # ---------------------------------------------------------------------- -class SSHManager: # pylint: disable=too-few-public-methods - local_bind_address = "127.0.0.1" - - @classmethod - def mutator(cls, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder) -> str: - # override any ssh tunnel configuration object - url = make_url_safe(sqlalchemy_url) - return url.set( - host=cls.local_bind_address, - port=server.local_bind_port, - ) - - -SSH_TUNNEL_MANAGER = SSHManager # pylint: disable=invalid-name +SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager" +SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1" # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. DEFAULT_FEATURE_FLAGS.update( diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 0ec9c60bd0c8..97a914711a72 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict - import sqlalchemy as sa from flask import current_app from flask_appbuilder import Model @@ -31,7 +29,6 @@ ) app_config = current_app.config -ssh_manager = app_config["SSH_TUNNEL_MANAGER"] class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): @@ -67,20 +64,3 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): private_key_password = sa.Column( EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) - - def kwarg_parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: - params = { - "ssh_address_or_host": self.server_address, - "ssh_port": self.server_port, - "ssh_username": self.username, - "remote_bind_address": (bind_host, bind_port), - "local_bind_address": (ssh_manager.local_bind_address,), - } - - if self.password: - params["ssh_password"] = self.password - elif self.private_key: - params["private_key"] = self.private_key - params["private_key_password"] = self.private_key_password - - return params diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index 1f5882f7492a..cccf3a526fe8 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -28,6 +28,7 @@ from flask_wtf.csrf import CSRFProtect from werkzeug.local import LocalProxy +from superset.extensions.ssh import SSHManagerFactory from superset.utils.async_query_manager import AsyncQueryManager from superset.utils.cache_manager import CacheManager from superset.utils.encrypt import EncryptedFieldFactory @@ -127,3 +128,4 @@ def init_app(self, app: Flask) -> None: results_backend_manager = ResultsBackendManager() security_manager = LocalProxy(lambda: appbuilder.sm) talisman = Talisman() +ssh_manager_factory = SSHManagerFactory() diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py new file mode 100644 index 000000000000..4ae8d508fcb7 --- /dev/null +++ b/superset/extensions/ssh.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import importlib +from typing import TYPE_CHECKING + +from flask import Flask +from sshtunnel import open_tunnel, SSHTunnelForwarder + +from superset.databases.utils import make_url_safe + +if TYPE_CHECKING: + from superset.databases.ssh_tunnel.models import SSHTunnel + + +class SSHManager: + def __init__(self, app: Flask) -> None: + super().__init__() + self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] + + def build_sqla_url( # pylint: disable=no-self-use + self, sqlalchemy_url: str, server: SSHTunnelForwarder + ) -> str: + # override any ssh tunnel configuration object + url = make_url_safe(sqlalchemy_url) + return url.set( + host=server.local_bind_address[0], + port=server.local_bind_port, + ) + + def create_tunnel( + self, + ssh_tunnel: "SSHTunnel", + sqlalchemy_database_uri: str, + ) -> SSHTunnelForwarder: + url = make_url_safe(sqlalchemy_database_uri) + params = { + "ssh_address_or_host": ssh_tunnel.server_address, + "ssh_port": ssh_tunnel.server_port, + "ssh_username": ssh_tunnel.username, + "remote_bind_address": (url.host, url.port), # bind_port, bind_host + "local_bind_address": (self.local_bind_address,), + } + + if ssh_tunnel.password: + params["ssh_password"] = ssh_tunnel.password + 
elif ssh_tunnel.private_key: + params["private_key"] = ssh_tunnel.private_key + params["private_key_password"] = ssh_tunnel.private_key_password + + return open_tunnel(**params) + + +class SSHManagerFactory: + def __init__(self) -> None: + self._ssh_manager = None + + def init_app(self, app: Flask) -> None: + ssh_manager_fqclass = app.config["SSH_TUNNEL_MANAGER_CLASS"] + ssh_manager_classname = ssh_manager_fqclass[ + ssh_manager_fqclass.rfind(".") + 1 : + ] + ssh_manager_module_name = ssh_manager_fqclass[ + 0 : ssh_manager_fqclass.rfind(".") + ] + ssh_manager_class = getattr( + importlib.import_module(ssh_manager_module_name), ssh_manager_classname + ) + + self._ssh_manager = ssh_manager_class(app) + + @property + def instance(self) -> SSHManager: + return self._ssh_manager # type: ignore diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 8c53c4c8e7c3..2b02d5106e2d 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -45,6 +45,7 @@ migrate, profiling, results_backend_manager, + ssh_manager_factory, talisman, ) from superset.security import SupersetSecurityManager @@ -417,6 +418,7 @@ def init_app_in_ctx(self) -> None: self.configure_data_sources() self.configure_auth_provider() self.configure_async_queries() + self.configure_ssh_manager() # Hook that provides administrators a handle on the Flask APP # after initialization @@ -474,6 +476,9 @@ def init_app(self) -> None: def configure_auth_provider(self) -> None: machine_auth_provider_factory.init_app(self.superset_app) + def configure_ssh_manager(self) -> None: + ssh_manager_factory.init_app(self.superset_app) + def setup_event_logger(self) -> None: _event_logger["event_logger"] = get_event_logger_from_cfg_value( self.superset_app.config.get("EVENT_LOGGER", DBEventLogger()) diff --git a/superset/models/core.py b/superset/models/core.py index ee0bc2c4659a..50c00ac615f3 100755 --- a/superset/models/core.py +++ b/superset/models/core.py 
@@ -29,7 +29,6 @@ import numpy import pandas as pd import sqlalchemy as sqla -import sshtunnel from flask import g, request from flask_appbuilder import Model from sqlalchemy import ( @@ -58,7 +57,12 @@ from superset.constants import PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain -from superset.extensions import cache_manager, encrypted_field_factory, security_manager +from superset.extensions import ( + cache_manager, + encrypted_field_factory, + security_manager, + ssh_manager_factory, +) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils @@ -66,7 +70,6 @@ from superset.utils.memoized import memoized config = app.config -ssh_manager = config["SSH_TUNNEL_MANAGER"] custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] stats_logger = config["STATS_LOGGER"] log_query = config["QUERY_LOGGER"] @@ -375,29 +378,33 @@ def get_sqla_engine_with_context( source: Optional[utils.QuerySource] = None, override_ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: - ssh_params = {} from superset.databases.dao import ( # pylint: disable=import-outside-toplevel DatabaseDAO, ) - if ssh_tunnel := override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( + sqlalchemy_uri = self.sqlalchemy_uri_decrypted + engine_context = nullcontext() + ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( database_id=self.id - ): + ) + + if ssh_tunnel: # if ssh_tunnel is available build engine with information - url = make_url_safe(self.sqlalchemy_uri_decrypted) - ssh_params = ssh_tunnel.kwarg_parameters( - bind_host=url.host, bind_port=url.port + engine_context = ssh_manager_factory.instance.create_tunnel( + ssh_tunnel=ssh_tunnel, + sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted, ) - engine_context = sshtunnel.open_tunnel(**ssh_params) - else: - engine_context = 
nullcontext() with engine_context as server_context: + if ssh_tunnel: + sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( + sqlalchemy_uri, server_context + ) yield self._get_sqla_engine( schema=schema, nullpool=nullpool, source=source, - ssh_tunnel_server=server_context, + sqlalchemy_uri=sqlalchemy_uri, ) def _get_sqla_engine( @@ -405,10 +412,12 @@ def _get_sqla_engine( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, - ssh_tunnel_server: Optional[sshtunnel.SSHTunnelForwarder] = None, + sqlalchemy_uri: Optional[str] = None, ) -> Engine: extra = self.get_extra() - sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted) + sqlalchemy_url = make_url_safe( + sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted + ) sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) effective_username = self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username @@ -448,10 +457,6 @@ def _get_sqla_engine( sqlalchemy_url, params = DB_CONNECTION_MUTATOR( sqlalchemy_url, params, effective_username, security_manager, source ) - - if ssh_tunnel_server: - # update sqlalchemy_url with ssh tunnel manager info - sqlalchemy_url = ssh_manager.mutator(sqlalchemy_url, ssh_tunnel_server) try: return create_engine(sqlalchemy_url, **params) except Exception as ex: From 54d51e21d05100b343cf1082c62bb442a8bc08e0 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 16 Dec 2022 12:31:26 -0500 Subject: [PATCH 69/75] add flag to indicate ssh tunneling is enabled for this engine --- superset/db_engine_specs/base.py | 1 + superset/db_engine_specs/postgres.py | 1 + 2 files changed, 2 insertions(+) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 87951d396ef1..cff1630c02ed 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -192,6 +192,7 @@ class BaseEngineSpec: # pylint: 
disable=too-many-public-methods engine_aliases: Set[str] = set() drivers: Dict[str, str] = {} default_driver: Optional[str] = None + allow_ssh_tunneling = False _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 286b6e80a1ca..3a6a2e17d89d 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -166,6 +166,7 @@ def epoch_to_dttm(cls) -> str: class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): engine = "postgresql" engine_aliases = {"postgres"} + allow_ssh_tunneling = True default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( From 53eaa63e73a83477fd3236614677b3348f8224eb Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Mon, 19 Dec 2022 14:04:44 -0500 Subject: [PATCH 70/75] Update superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py Co-authored-by: Elizabeth Thompson --- ...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 75bad1e53ed5..8b3d436fc04e 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -49,7 +49,7 @@ def upgrade(): # ExtraJSONMixin sa.Column("extra_json", sa.Text(), nullable=True), # ImportExportMixin - sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), + sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True, index=True), # SSHTunnelCredentials sa.Column("id", sa.Integer(), primary_key=True), 
sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), unique=True), From 8f8faff780fb77dec23698ae443f471e310578d0 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Mon, 19 Dec 2022 14:09:03 -0500 Subject: [PATCH 71/75] Update superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py Co-authored-by: Elizabeth Thompson --- ...0-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 8b3d436fc04e..1caad6e54041 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -52,7 +52,7 @@ def upgrade(): sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True, index=True), # SSHTunnelCredentials sa.Column("id", sa.Integer(), primary_key=True), - sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), unique=True), + sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), unique=True, index=True), sa.Column("server_address", sa.String(256)), sa.Column("server_port", sa.INTEGER()), sa.Column("username", encrypted_field_factory.create(sa.String(256))), From 607c68289ae2e07f297319a50daa32aead4d0b38 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 19 Dec 2022 20:41:43 -0500 Subject: [PATCH 72/75] fix linting --- ...8ec8595_create_ssh_tunnel_credentials_tbl.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index 
1caad6e54041..b373020cb14c 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -49,10 +49,23 @@ def upgrade(): # ExtraJSONMixin sa.Column("extra_json", sa.Text(), nullable=True), # ImportExportMixin - sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4, unique=True, index=True), + sa.Column( + "uuid", + UUIDType(binary=True), + primary_key=False, + default=uuid4, + unique=True, + index=True, + ), # SSHTunnelCredentials sa.Column("id", sa.Integer(), primary_key=True), - sa.Column("database_id", sa.INTEGER(), sa.ForeignKey("dbs.id"), unique=True, index=True), + sa.Column( + "database_id", + sa.INTEGER(), + sa.ForeignKey("dbs.id"), + unique=True, + index=True, + ), sa.Column("server_address", sa.String(256)), sa.Column("server_port", sa.INTEGER()), sa.Column("username", encrypted_field_factory.create(sa.String(256))), From 7cc7bc8250ea200cf8dfe7afa66a637005ee91e5 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 22 Dec 2022 17:19:40 -0500 Subject: [PATCH 73/75] fix requirements --- requirements/base.txt | 7 ------- 1 file changed, 7 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index b3578180d11a..4b7363ca18c6 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -31,15 +31,8 @@ cachelib==0.4.1 # via apache-superset celery==5.2.2 # via apache-superset -<<<<<<< HEAD -cffi==1.14.6 - # via - # cryptography - # pynacl -======= cffi==1.15.1 # via cryptography ->>>>>>> 630c129e3e3e8a48c22d754e5d9943583ac0dae4 click==8.0.4 # via # apache-superset From 394afc16d6297082b775a6b25a183c6161699d39 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 3 Jan 2023 15:33:14 -0500 Subject: [PATCH 74/75] get df with get_raw_connection function --- superset/models/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/models/core.py 
b/superset/models/core.py index 0968151898c6..173bd5b59075 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -515,7 +515,7 @@ def _log_query(sql: str) -> None: security_manager, ) - with closing(engine.raw_connection()) as conn: + with self.get_raw_connection(schema=schema) as conn: cursor = conn.cursor() for sql_ in sqls[:-1]: _log_query(sql_) From 9b09fc7791d66ed1676a509614edef93e3fd43d4 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Tue, 3 Jan 2023 22:05:22 +0100 Subject: [PATCH 75/75] feat(ssh_tunnel): APIs for SSH Tunnels (#22199) Co-authored-by: hughhhh --- superset/constants.py | 1 + superset/databases/api.py | 111 +++++++ superset/databases/commands/create.py | 33 +- superset/databases/commands/update.py | 35 +- superset/databases/schemas.py | 1 + .../databases/ssh_tunnel/commands/create.py | 16 +- superset/databases/ssh_tunnel/models.py | 10 + superset/utils/ssh_tunnel.py | 30 ++ .../integration_tests/databases/api_tests.py | 310 ++++++++++++++++++ tests/unit_tests/databases/api_test.py | 144 ++++++++ 10 files changed, 676 insertions(+), 15 deletions(-) create mode 100644 superset/utils/ssh_tunnel.py diff --git a/superset/constants.py b/superset/constants.py index 5091d65a432d..ea7920ff2fd7 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -139,6 +139,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "validate_sql": "read", "get_data": "read", "samples": "read", + "delete_ssh_tunnel": "write", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index 3f737ec4da1a..1c75204f7964 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,6 +72,11 @@ ValidateSQLRequest, ValidateSQLResponse, ) +from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + 
SSHTunnelNotFoundError, +) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -80,6 +85,7 @@ from superset.models.core import Database from superset.superset_typing import FlaskResponse from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item +from superset.utils.ssh_tunnel import mask_password_info from superset.views.base import json_errors_response from superset.views.base_api import ( BaseSupersetModelRestApi, @@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "available", "validate_parameters", "validate_sql", + "delete_ssh_tunnel", } resource_name = "database" class_permission_name = "Database" @@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ValidateSQLResponse, ) + @expose("/", methods=["GET"]) + @protect() + @safe + def get(self, pk: int, **kwargs: Any) -> Response: + """Get a database + --- + get: + description: >- + Get a database + parameters: + - in: path + schema: + type: integer + description: The database id + name: pk + responses: + 200: + description: Database + content: + application/json: + schema: + type: object + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + data = self.get_headless(pk, **kwargs) + try: + if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk): + payload = data.json + payload["result"]["ssh_tunnel"] = ssh_tunnel.data + return payload + return data + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("/", methods=["POST"]) @protect() @safe @@ -280,6 +328,12 @@ def post(self) -> FlaskResponse: if new_model.driver: item["driver"] = new_model.driver + # Return SSH Tunnel and hide passwords if any + if item.get("ssh_tunnel"): + item["ssh_tunnel"] = 
mask_password_info( + new_model.ssh_tunnel # pylint: disable=no-member + ) + return self.response(201, id=new_model.id, result=item) except DatabaseInvalidError as ex: return self.response_422(message=ex.normalized_messages()) @@ -361,6 +415,9 @@ def put(self, pk: int) -> Response: item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri if changed_model.parameters: item["parameters"] = changed_model.parameters + # Return SSH Tunnel and hide passwords if any + if item.get("ssh_tunnel"): + item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel) return self.response(200, id=changed_model.id, result=item) except DatabaseNotFoundError: return self.response_404() @@ -1206,3 +1263,57 @@ def validate_parameters(self) -> FlaskResponse: command = ValidateDatabaseParametersCommand(payload) command.run() return self.response(200, message="OK") + + @expose("//ssh_tunnel/", methods=["DELETE"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".delete_ssh_tunnel", + log_to_statsd=False, + ) + def delete_ssh_tunnel(self, pk: int) -> Response: + """Deletes a SSH Tunnel + --- + delete: + description: >- + Deletes a SSH Tunnel. 
+ parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: SSH Tunnel deleted + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteSSHTunnelCommand(pk).run() + return self.response(200, message="OK") + except SSHTunnelNotFoundError: + return self.response_404() + except SSHTunnelDeleteFailedError as ex: + logger.error( + "Error deleting SSH Tunnel %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=str(ex)) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 3b5874aabe02..c826d8283574 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -31,7 +31,11 @@ ) from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.dao import DatabaseDAO -from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, +) from superset.exceptions import SupersetErrorsException from superset.extensions import db, event_logger, security_manager @@ -71,17 +75,28 @@ def run(self) -> Model: try: database = DatabaseDAO.create(self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - db.session.flush() ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): - ssh_tunnel = SSHTunnelDAO.create( - { - **ssh_tunnel_properties, - "database_id": database.id, - }, - commit=False, - ) + try: + # So database.id is not None + db.session.flush() + 
ssh_tunnel = CreateSSHTunnelCommand( + database.id, ssh_tunnel_properties + ).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + # So we can show the original message + raise ex + except Exception as ex: + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + raise DatabaseCreateFailedError() from ex # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 80e3a9b54e61..2e5931788ee6 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -21,7 +21,7 @@ from marshmallow import ValidationError from superset.commands.base import BaseCommand -from superset.dao.exceptions import DAOUpdateFailedError +from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.databases.commands.exceptions import ( DatabaseConnectionFailedError, DatabaseExistsValidationError, @@ -30,6 +30,12 @@ DatabaseUpdateFailedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, +) +from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand from superset.extensions import db, security_manager from superset.models.core import Database from superset.utils.core import DatasourceType @@ -94,10 +100,33 @@ def run(self) -> Model: security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) + + 
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id) + if existing_ssh_tunnel_model is None: + # We couldn't found an existing tunnel so we need to create one + try: + CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + else: + # We found an existing tunnel so we need to update it + try: + UpdateSSHTunnelCommand( + existing_ssh_tunnel_model.id, ssh_tunnel_properties + ).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + db.session.commit() - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) + except (DAOUpdateFailedError, DAOCreateFailedError) as ex: raise DatabaseUpdateFailedError() from ex return database diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 5e9e3a552d29..1732b01ecc2a 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -468,6 +468,7 @@ class Meta: # pylint: disable=too-few-public-methods ) is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index b2e62f340b0b..9c17149ba3d0 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -28,7 +28,7 @@ SSHTunnelRequiredFieldValidationError, ) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO -from superset.extensions import 
event_logger +from superset.extensions import db, event_logger logger = logging.getLogger(__name__) @@ -39,12 +39,22 @@ def __init__(self, database_id: int, data: Dict[str, Any]): self._properties["database_id"] = database_id def run(self) -> Model: - self.validate() - try: + # Start nested transaction since we are always creating the tunnel + # through a DB command (Create or Update). Without this, we cannot + # safely rollback changes to databases if any, i.e, things like + # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail + db.session.begin_nested() + self.validate() tunnel = SSHTunnelDAO.create(self._properties, commit=False) except DAOCreateFailedError as ex: + # Rollback nested transaction + db.session.rollback() raise SSHTunnelCreateFailedError() from ex + except SSHTunnelInvalidError as ex: + # Rollback nested transaction + db.session.rollback() + raise ex return tunnel diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 97a914711a72..79e8b918d9e1 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+from typing import Any, Dict + import sqlalchemy as sa from flask import current_app from flask_appbuilder import Model @@ -64,3 +66,11 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): private_key_password = sa.Column( EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True ) + + @property + def data(self) -> Dict[str, Any]: + return { + "server_address": self.server_address, + "server_port": self.server_port, + "username": self.username, + } diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py new file mode 100644 index 000000000000..6562a8bbb570 --- /dev/null +++ b/superset/utils/ssh_tunnel.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict + +from superset.constants import PASSWORD_MASK + + +def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: + if ssh_tunnel.pop("password", None) is not None: + ssh_tunnel["password"] = PASSWORD_MASK + if ssh_tunnel.pop("private_key", None) is not None: + ssh_tunnel["private_key"] = PASSWORD_MASK + if ssh_tunnel.pop("private_key_password", None) is not None: + ssh_tunnel["private_key_password"] = PASSWORD_MASK + return ssh_tunnel diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 8a96184b8195..aeb74ec91e4f 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -35,6 +35,8 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.databases.utils import make_url_safe from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.db_engine_specs.redshift import RedshiftEngineSpec @@ -280,6 +282,314 @@ def test_create_database(self): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_create_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test create with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": 
ssh_tunnel_properties, + } + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, 
response_update.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_ssh_tunnel_via_database_api( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + + if example_db.backend == "sqlite": + return + initial_ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + updated_ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "Test", + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": initial_ssh_tunnel_properties, + } + database_data_with_ssh_tunnel_update = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": updated_ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + self.assertEqual(model_ssh_tunnel.username, "foo") + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + 
.filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + self.assertEqual(model_ssh_tunnel.username, "Test") + self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1") + self.assertEqual(model_ssh_tunnel.server_port, 8080) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_cascade_delete_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test SSH Tunnel is cascade deleted with its database + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one_or_none() + ) + assert model_ssh_tunnel is None + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + 
"superset.models.core.Database.get_all_schema_names", + ) + def test_do_not_create_database_if_ssh_tunnel_creation_fails( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test database is not created when SSH Tunnel creation fails + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + } + database_data = { + "database_name": "test-db-failure-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + fail_message = {"message": "SSH Tunnel parameters are invalid."} + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one_or_none() + ) + assert model_ssh_tunnel is None + self.assertEqual(response, fail_message) + # Cleanup + model = ( + db.session.query(Database) + .filter(Database.database_name == "test-db-failure-ssh-tunnel") + .one_or_none() + ) + # the DB should not be created + assert model is None + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_get_database_returns_related_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test GET Database returns its related SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": 
example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + response_ssh_tunnel = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "XXXXXXXXXX", + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + def test_create_database_invalid_configuration_method(self): """ Database API: Test create with an invalid configuration method. diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index d6f8897c4a09..fe4211289caf 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None: } ] } + + +def test_delete_ssh_tunnel( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we can delete SSH Tunnel + """ + with app.app_context(): + from superset.databases.api import DatabaseRestApi + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": 
"service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) + + session.add(tunnel) + session.commit() + + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id + + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "OK" + + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel is None + + +def test_delete_ssh_tunnel_not_found( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we cannot delete a tunnel that does not exist + """ + with app.app_context(): + from superset.databases.api import DatabaseRestApi + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = 
session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) + + session.add(tunnel) + session.commit() + + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "Not found" + + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id + + response_tunnel = DatabaseDAO.get_ssh_tunnel(2) + assert response_tunnel is None