Skip to content

Commit

Permalink
Add Table performance improvements (#3509)
Browse files Browse the repository at this point in the history
* Improved performance of 'Add table' function

* got rid of pvt function call

* changes metric obj to key on metric_name
  • Loading branch information
Mogball authored and mistercrunch committed Sep 25, 2017
1 parent 255ea69 commit f3146ef
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
49 changes: 21 additions & 28 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DateTime,
)
import sqlalchemy as sa
from sqlalchemy import asc, and_, desc, select
from sqlalchemy import asc, and_, desc, select, or_
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import table, literal_column, text, column
Expand Down Expand Up @@ -588,45 +588,41 @@ def fetch_metadata(self):
table = self.get_sqla_table_object()
except Exception:
raise Exception(_(
"Table doesn't seem to exist in the specified database, "
"couldn't fetch column information"))
"Table [{}] doesn't seem to exist in the specified database, "
"couldn't fetch column information").format(self.table_name))

TC = TableColumn # noqa shortcut to class
M = SqlMetric # noqa
metrics = []
any_date_col = None
db_dialect = self.database.get_sqla_engine().dialect
db_dialect = self.database.get_dialect()
dbcols = (
db.session.query(TableColumn)
.filter(TableColumn.table == self)
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}

for col in table.columns:
try:
datatype = "{}".format(col.type.compile(dialect=db_dialect)).upper()
datatype = col.type.compile(dialect=db_dialect).upper()
except Exception as e:
datatype = "UNKNOWN"
logging.error(
"Unrecognized data type in {}.{}".format(table, col.name))
logging.exception(e)
dbcol = (
db.session
.query(TC)
.filter(TC.table == self)
.filter(TC.column_name == col.name)
.first()
)
db.session.flush()
dbcol = dbcols.get(col.name, None)
if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype)
dbcol.groupby = dbcol.is_string
dbcol.filterable = dbcol.is_string
dbcol.sum = dbcol.is_num
dbcol.avg = dbcol.is_num
dbcol.is_dttm = dbcol.is_time

db.session.merge(self)
self.columns.append(dbcol)

if not any_date_col and dbcol.is_time:
any_date_col = col.name

quoted = "{}".format(col.compile(dialect=db_dialect))
quoted = str(col.compile(dialect=db_dialect))
if dbcol.sum:
metrics.append(M(
metric_name='sum__' + dbcol.column_name,
Expand Down Expand Up @@ -663,28 +659,25 @@ def fetch_metadata(self):
expression="COUNT(DISTINCT {})".format(quoted)
))
dbcol.type = datatype
db.session.merge(self)
db.session.commit()

metrics.append(M(
metric_name='count',
verbose_name='COUNT(*)',
metric_type='count',
expression="COUNT(*)"
))

dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter(
or_(M.metric_name == metric.metric_name for metric in metrics))
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics:
m = (
db.session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.table_id == self.id)
.first()
)
metric.table_id = self.id
if not m:
if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric)
db.session.commit()
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
db.session.merge(self)
db.session.commit()

@classmethod
def import_obj(cls, i_datasource, import_time=None):
Expand Down
23 changes: 9 additions & 14 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Views used by the SqlAlchemy connector"""
import logging

from past.builtins import basestring

from flask import Markup, flash, redirect
Expand Down Expand Up @@ -229,21 +228,17 @@ class TableModelView(DatasourceModelView, DeleteMixin): # noqa
}

def pre_add(self, table):
number_of_existing_tables = db.session.query(
sa.func.count('*')).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id
).scalar()
# table object is already added to the session
if number_of_existing_tables > 1:
raise Exception(get_datasource_exist_error_mgs(table.full_name))
with db.session.no_autoflush:
table_query = db.session.query(models.SqlaTable).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id)
if db.session.query(table_query.exists()).scalar():
raise Exception(
get_datasource_exist_error_mgs(table.full_name))

# Fail before adding if the table can't be found
try:
table.get_sqla_table_object()
except Exception as e:
logging.exception(e)
if not table.database.has_table(table):
raise Exception(_(
"Table [{}] could not be found, "
"please double check your "
Expand Down
10 changes: 10 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.engine import url
from sqlalchemy_utils import EncryptedType

from superset import app, db, db_engine_specs, utils, sm
Expand Down Expand Up @@ -743,6 +744,15 @@ def get_perm(self):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)

def has_table(self, table):
    """Ask the engine's dialect whether ``table`` exists.

    :param table: object exposing ``table_name`` and ``schema`` attributes
    :returns: whatever the dialect's ``has_table`` returns for that
        table name and schema
    """
    sqla_engine = self.get_sqla_engine()
    exists = sqla_engine.dialect.has_table
    # A falsy schema (e.g. an empty string) is passed on as None.
    return exists(sqla_engine, table.table_name, table.schema or None)

def get_dialect(self):
    """Build a dialect instance from the decrypted SQLAlchemy URI.

    The URI in ``self.sqlalchemy_uri_decrypted`` is parsed with
    ``url.make_url``; the object returned by its ``get_dialect()`` is
    then called with no arguments and that result is returned.
    ``get_sqla_engine`` is not called here.
    """
    dialect_factory = url.make_url(
        self.sqlalchemy_uri_decrypted).get_dialect()
    return dialect_factory()


sqla.event.listen(Database, 'after_insert', set_perm)
sqla.event.listen(Database, 'after_update', set_perm)
Expand Down

0 comments on commit f3146ef

Please sign in to comment.