Skip to content

Commit

Permalink
fix: Refactor SQL username logic (#19914)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and John Bodley committed May 13, 2022
1 parent fff9ad0 commit 449d08b
Show file tree
Hide file tree
Showing 22 changed files with 388 additions and 340 deletions.
38 changes: 25 additions & 13 deletions superset/cli/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import superset.utils.database as database_utils
from superset.extensions import db
from superset.utils.core import override_user
from superset.utils.encrypt import SecretsMigrator

logger = logging.getLogger(__name__)
Expand All @@ -54,23 +55,34 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None:

@click.command()
@with_appcontext
def update_datasources_cache() -> None:
@click.option(
"--username",
"-u",
default=None,
help=(
"Specify which user should execute the underlying SQL queries. If undefined "
"defaults to the user registered with the database connection."
),
)
def update_datasources_cache(username: Optional[str]) -> None:
"""Refresh sqllab datasources cache"""
# pylint: disable=import-outside-toplevel
from superset import security_manager
from superset.models.core import Database

for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch:
print("Fetching {} datasources ...".format(database.name))
try:
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as ex: # pylint: disable=broad-except
print("{}".format(str(ex)))
with override_user(security_manager.find_user(username)):
for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch:
print("Fetching {} datasources ...".format(database.name))
try:
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as ex: # pylint: disable=broad-except
print("{}".format(str(ex)))


@click.command()
Expand Down
11 changes: 8 additions & 3 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
# database,
# query,
# schema=None,
# user=None,
# user=None, # TODO(john-bodley): Deprecate in 3.0.
# client=None,
# security_manager=None,
# log_params=None,
Expand Down Expand Up @@ -1020,9 +1020,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information
#
# def SQL_QUERY_MUTATOR(sql, user_name=user_name, security_manager=security_manager, database=database):
# def SQL_QUERY_MUTATOR(
# sql,
# user_name=user_name, # TODO(john-bodley): Deprecate in 3.0.
# security_manager=security_manager,
# database=database,
# ):
# dttm = datetime.now().isoformat()
# return f"-- [SQL LAB] {username} {dttm}\n{sql}"
# return f"-- [SQL LAB] {user_name} {dttm}\n{sql}"
# For backward compatibility, you can unpack any of the above arguments in your
# function definition, but keep the **kwargs as the last argument to allow new args
# to be added later without any errors.
Expand Down
4 changes: 2 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
from superset.utils.core import (
GenericDataType,
get_column_name,
get_username,
is_adhoc_column,
MediumText,
QueryObjectFilterClause,
Expand Down Expand Up @@ -917,10 +918,9 @@ def mutate_query_from_config(self, sql: str) -> str:
Typically adds comments to the query with context"""
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
username = utils.get_username()
sql = sql_query_mutator(
sql,
user_name=username,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=self.database,
)
Expand Down
74 changes: 38 additions & 36 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from superset.exceptions import SupersetSecurityException, SupersetTimeoutException
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.core import override_user

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,42 +75,43 @@ def run(self) -> None:

database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
)
with closing(engine.raw_connection()) as conn:
try:
alive = func_timeout(
int(
app.config[
"TEST_DATABASE_CONNECTION_TIMEOUT"
].total_seconds()
),
engine.dialect.do_ping,
args=(conn,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(conn)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception: # pylint: disable=broad-except
alive = False
if not alive:
raise DBAPIError(None, None, None)

with override_user(self._actor):
engine = database.get_sqla_engine()
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
)
with closing(engine.raw_connection()) as conn:
try:
alive = func_timeout(
int(
app.config[
"TEST_DATABASE_CONNECTION_TIMEOUT"
].total_seconds()
),
engine.dialect.do_ping,
args=(conn,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(conn)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database "
"settings, and ensure that your database is accepting "
"connections, then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception: # pylint: disable=broad-except
alive = False
if not alive:
raise DBAPIError(None, None, None)

# Log succesful connection test with engine
event_logger.log_with_context(
Expand Down
34 changes: 18 additions & 16 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.core import override_user

BYPASS_VALIDATION_ENGINES = {"bigquery"}

Expand Down Expand Up @@ -115,22 +116,23 @@ def run(self) -> None:
)
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex

with override_user(self._actor):
engine = database.get_sqla_engine()
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex

if not alive:
raise DatabaseOfflineError(
Expand Down
21 changes: 6 additions & 15 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import current_app, g
from flask import current_app
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Range
Expand All @@ -64,7 +64,7 @@
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import ColumnSpec, GenericDataType, get_username
from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open

Expand Down Expand Up @@ -392,10 +392,7 @@ def get_engine(
schema: Optional[str] = None,
source: Optional[str] = None,
) -> Engine:
user_name = utils.get_username()
return database.get_sqla_engine(
schema=schema, nullpool=True, user_name=user_name, source=source
)
return database.get_sqla_engine(schema=schema, source=source)

@classmethod
def get_timestamp_expr(
Expand Down Expand Up @@ -1158,15 +1155,12 @@ def query_cost_formatter(
raise Exception("Database does not support cost estimation")

@classmethod
def process_statement(
cls, statement: str, database: "Database", user_name: str
) -> str:
def process_statement(cls, statement: str, database: "Database") -> str:
"""
Process a SQL statement by stripping and mutating it.
:param statement: A single SQL statement
:param database: Database instance
:param user_name: Effective username
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
Expand All @@ -1175,7 +1169,7 @@ def process_statement(
if sql_query_mutator:
sql = sql_query_mutator(
sql,
user_name=user_name,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
Expand All @@ -1198,7 +1192,6 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()

Expand All @@ -1207,9 +1200,7 @@ def estimate_query_cost(
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(
statement, database, user_name
)
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
return costs

Expand Down
Loading

0 comments on commit 449d08b

Please sign in to comment.