Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing confusion when selecting schema across engines #2572

Merged
merged 1 commit into from
Apr 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ def extract_error_message(cls, e):
"""Extract error message for queries"""
return utils.error_msg_from_exception(e)

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Based on a URI and selected schema, return a new URI

    The URI here represents the URI as entered when saving the database;
    ``selected_schema`` is the schema currently active, presumably in
    the SQL Lab dropdown. Based on that, for some database engines,
    we can return a new altered URI that connects straight to the
    active schema, meaning the users won't have to prefix the object
    names with the schema name.

    Some database engines have 2 levels of namespacing: database and
    schema (postgres, oracle, mssql, ...).
    For those it's probably better not to alter the database
    component of the URI with the schema name; it won't work.

    Some database drivers like presto accept "{catalog}/{schema}" in
    the database component of the URL; that can be handled here.

    This default implementation ignores ``selected_schema`` and returns
    ``uri`` unchanged. ``selected_schema`` defaults to ``None`` so the
    signature matches the engine-specific overrides.
    """
    return uri

@classmethod
def sql_preprocessor(cls, sql):
"""If the SQL needs to be altered prior to running it
Expand Down Expand Up @@ -290,6 +311,12 @@ def convert_dttm(cls, target_type, dttm):
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point ``uri`` at ``selected_schema`` when one is given.

    The ``database`` attribute of ``uri`` is overwritten in place with
    ``selected_schema``. When ``selected_schema`` is falsy, ``uri`` is
    returned untouched. The same ``uri`` object is returned either way.
    """
    if not selected_schema:
        return uri
    uri.database = selected_schema
    return uri

@classmethod
def epoch_to_dttm(cls):
    """Return the ``from_unixtime({col})`` expression template.

    The ``{col}`` placeholder is returned as-is, not substituted here.
    """
    template = "from_unixtime({col})"
    return template
Expand Down Expand Up @@ -328,6 +355,17 @@ def patch(cls):
from superset.db_engines import presto as patched_presto
presto.Cursor.cancel = patched_presto.cancel

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Rewrite the database component of ``uri`` to "{catalog}/{schema}".

    The catalog is whatever precedes the first ``/`` in ``uri.database``
    (or the whole value when there is no ``/``). Any schema already in
    the URI is replaced by ``selected_schema``.

    ``uri`` is modified in place and returned. It is returned unchanged
    when ``selected_schema`` is falsy, or when the URI carries no
    database component: there is then no catalog to prefix the schema
    with, and testing ``'/' in None`` would raise ``TypeError``.
    """
    database = uri.database
    if not selected_schema or not database:
        return uri
    # Both "catalog" and "catalog/old_schema" reduce to the same catalog.
    catalog = database.split('/')[0]
    uri.database = catalog + '/' + selected_schema
    return uri

@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
Expand Down
25 changes: 4 additions & 21 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,26 +560,10 @@ def set_sqlalchemy_uri(self, uri):

def get_sqla_engine(self, schema=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
uri = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
url.database = self.get_database_for_various_backend(url, schema)
return create_engine(url, **params)

def get_database_for_various_backend(self, uri, default_database=None):
database = uri.database
if self.backend == 'presto' and default_database:
if '/' in database:
database = database.split('/')[0] + '/' + default_database
else:
database += '/' + default_database
# Postgres and Redshift use the concept of schema as a logical entity
# on top of the database, so the database should not be changed
# even if passed default_database
elif self.backend in ('redshift', 'postgresql', 'sqlite'):
pass
elif default_database:
database = default_database
return database
uri = self.db_engine_spec.adjust_database_uri(uri, schema)
return create_engine(uri, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
Expand Down Expand Up @@ -662,9 +646,8 @@ def all_schema_names(self):

@property
def db_engine_spec(self):
engine_name = self.get_sqla_engine().name or 'base'
return db_engine_specs.engines.get(
engine_name, db_engine_specs.BaseEngineSpec)
self.backend, db_engine_specs.BaseEngineSpec)

def grains(self):
"""Defines time granularity database-specific expressions.
Expand Down
5 changes: 3 additions & 2 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def generate_download_headers(extension):
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.Database)
list_columns = [
'verbose_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'changed_on_', 'database_name']
'database_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'modified']
add_columns = [
'database_name', 'sqlalchemy_uri', 'cache_timeout', 'extra',
'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
Expand Down Expand Up @@ -1351,6 +1351,7 @@ def testconn(self):
engine.connect()
return json.dumps(engine.table_names(), indent=4)
except Exception as e:
logging.exception(e)
return json_error_response((
"Connection failed!\n\n"
"The error message returned was:\n{}").format(e))
Expand Down
68 changes: 38 additions & 30 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,51 @@


class DatabaseModelTestCase(unittest.TestCase):
def test_database_for_various_backend(self):

def test_database_schema_presto(self):
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'hive/default'
db = model.get_database_for_various_backend(url, 'raw_data')
assert db == 'hive/raw_data'

sqlalchemy_uri = 'redshift+psycopg2://superset:XXXXXXXXXX@redshift.airbnb.io:5439/prod'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'test')
assert db == 'prod'

sqlalchemy_uri = 'postgresql+psycopg2://superset:XXXXXXXXXX@postgres.airbnb.io:5439/prod'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

def test_database_schema_postgres(self):
sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'prod'

sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/raw_data'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('prod', db)

db = make_url(model.get_sqla_engine(schema='foo').url).database
self.assertEquals('prod', db)

def test_database_schema_hive(self):
sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'raw_data'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

sqlalchemy_uri = 'mysql://superset:XXXXXXXXXX@mysql.airbnb.io/superset'
def test_database_schema_mysql(self):
sqlalchemy_uri = 'mysql://root@localhost/superset'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'superset'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'

db = make_url(model.get_sqla_engine().url).database
self.assertEquals('superset', db)

db = make_url(model.get_sqla_engine(schema='staging').url).database
self.assertEquals('staging', db)