From 2ab6a411f4b9fa75bc86b83f4aa4e3707e90e002 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Mon, 9 Jan 2017 10:02:03 -0800 Subject: [PATCH] Druid dashboard import/export. (#1930) * Druid dashboard import * Fix tests. * Parse params with trailing commas * Resolved issue with datasource is not in DB yet when slice perms are set * Finish rebase * Fix heads alembic problem. --- superset/import_util.py | 74 ++ .../versions/1296d28ec131_druid_exports.py | 22 + superset/models.py | 870 +++++++++--------- superset/source_registry.py | 25 +- superset/views.py | 8 +- tests/import_export_tests.py | 151 ++- 6 files changed, 686 insertions(+), 464 deletions(-) create mode 100644 superset/import_util.py create mode 100644 superset/migrations/versions/1296d28ec131_druid_exports.py diff --git a/superset/import_util.py b/superset/import_util.py new file mode 100644 index 000000000000..e71623d57926 --- /dev/null +++ b/superset/import_util.py @@ -0,0 +1,74 @@ +import logging +from sqlalchemy.orm.session import make_transient + + +def import_datasource( + session, + i_datasource, + lookup_database, + lookup_datasource, + import_time): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. + """ + make_transient(i_datasource) + logging.info('Started import of the datasource: {}'.format( + i_datasource.to_json())) + + i_datasource.id = None + i_datasource.database_id = lookup_database(i_datasource).id + i_datasource.alter_params(import_time=import_time) + + # override the datasource + datasource = lookup_datasource(i_datasource) + + if datasource: + datasource.override(i_datasource) + session.flush() + else: + datasource = i_datasource.copy() + session.add(datasource) + session.flush() + + for m in i_datasource.metrics: + new_m = m.copy() + new_m.table_id = datasource.id + logging.info('Importing metric {} from the datasource: {}'.format( + new_m.to_json(), i_datasource.full_name)) + imported_m = i_datasource.metric_cls.import_obj(new_m) + if (imported_m.metric_name not in + [m.metric_name for m in datasource.metrics]): + datasource.metrics.append(imported_m) + + for c in i_datasource.columns: + new_c = c.copy() + new_c.table_id = datasource.id + logging.info('Importing column {} from the datasource: {}'.format( + new_c.to_json(), i_datasource.full_name)) + imported_c = i_datasource.column_cls.import_obj(new_c) + if (imported_c.column_name not in + [c.column_name for c in datasource.columns]): + datasource.columns.append(imported_c) + session.flush() + return datasource.id + + +def import_simple_obj(session, i_obj, lookup_obj): + make_transient(i_obj) + i_obj.id = None + i_obj.table = None + + # find if the column was already imported + existing_column = lookup_obj(i_obj) + i_obj.table = None + if existing_column: + existing_column.override(i_obj) + session.flush() + return existing_column + + session.add(i_obj) + session.flush() + return i_obj diff --git a/superset/migrations/versions/1296d28ec131_druid_exports.py b/superset/migrations/versions/1296d28ec131_druid_exports.py new file mode 100644 index 000000000000..9005c03904e7 --- /dev/null +++ b/superset/migrations/versions/1296d28ec131_druid_exports.py @@ -0,0 +1,22 @@ +"""Adds params to the datasource (druid) table + +Revision ID: 1296d28ec131 +Revises: e46f2d27a08e +Create Date: 2016-12-06 17:40:40.389652 + +""" + +# revision identifiers, used by Alembic. 
+revision = '1296d28ec131' +down_revision = '1b2c3f7c96f9' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('datasources', sa.Column('params', sa.String(length=1000), nullable=True)) + + +def downgrade(): + op.drop_column('datasources', 'params') diff --git a/superset/models.py b/superset/models.py index cef46a4c57a8..df0ada3b952a 100644 --- a/superset/models.py +++ b/superset/models.py @@ -52,7 +52,9 @@ from werkzeug.datastructures import ImmutableMultiDict -from superset import app, db, db_engine_specs, get_session, utils, sm +from superset import ( + app, db, db_engine_specs, get_session, utils, sm, import_util +) from superset.source_registry import SourceRegistry from superset.viz import viz_types from superset.jinja_context import get_template_processor @@ -95,6 +97,13 @@ def set_perm(mapper, connection, target): # noqa ) +def set_related_perm(mapper, connection, target): # noqa + src_class = target.cls_model + id_ = target.datasource_id + ds = db.session.query(src_class).filter_by(id=int(id_)).first() + target.perm = ds.perm + + class JavascriptPostAggregator(Postaggregator): def __init__(self, name, field_names, function): self.post_aggregator = { @@ -126,7 +135,9 @@ def alter_params(self, **kwargs): @property def params_dict(self): if self.params: - return json.loads(self.params) + params = re.sub(",[ \t\r\n]+}", "}", self.params) + params = re.sub(",[ \t\r\n]+\]", "]", params) + return json.loads(params) else: return {} @@ -391,18 +402,11 @@ def import_obj(cls, slc_to_import, import_time=None): slc_to_override.override(slc_to_import) session.flush() return slc_to_override.id - else: - session.add(slc_to_import) - logging.info('Final slice: {}'.format(slc_to_import.to_json())) - session.flush() - return slc_to_import.id - + session.add(slc_to_import) + logging.info('Final slice: {}'.format(slc_to_import.to_json())) + session.flush() + return slc_to_import.id -def set_related_perm(mapper, connection, target): # noqa - src_class = target.cls_model - id_ = target.datasource_id - ds = db.session.query(src_class).filter_by(id=int(id_)).first() - target.perm = ds.perm sqla.event.listen(Slice, 'before_insert', set_related_perm) sqla.event.listen(Slice, 'before_update', set_related_perm) @@ -615,7 +619,7 @@ def export_dashboards(cls, dashboard_ids): remote_id=slc.id, datasource_name=slc.datasource.name, schema=slc.datasource.name, - database_name=slc.datasource.database.database_name, + database_name=slc.datasource.database.name, ) copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) @@ -626,7 +630,7 @@ def export_dashboards(cls, dashboard_ids): db.session, dashboard_type, dashboard_id) eager_datasource.alter_params( remote_id=eager_datasource.id, - database_name=eager_datasource.database.database_name, + database_name=eager_datasource.database.name, ) make_transient(eager_datasource) eager_datasources.append(eager_datasource) @@ -898,6 +902,168 @@ def get_perm(self): sqla.event.listen(Database, 'after_update', set_perm) +class TableColumn(Model, AuditMixinNullable, ImportMixin): + + """ORM object for table columns, each table can have multiple columns""" + + __tablename__ = 'table_columns' + id = Column(Integer, primary_key=True) + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', + backref=backref('columns', cascade='all, delete-orphan'), + foreign_keys=[table_id]) + column_name = Column(String(255)) + verbose_name = Column(String(1024)) + is_dttm = Column(Boolean, 
default=False) + is_active = Column(Boolean, default=True) + type = Column(String(32), default='') + groupby = Column(Boolean, default=False) + count_distinct = Column(Boolean, default=False) + sum = Column(Boolean, default=False) + avg = Column(Boolean, default=False) + max = Column(Boolean, default=False) + min = Column(Boolean, default=False) + filterable = Column(Boolean, default=False) + expression = Column(Text, default='') + description = Column(Text, default='') + python_date_format = Column(String(255)) + database_expression = Column(String(255)) + + num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') + date_types = ('DATE', 'TIME') + str_types = ('VARCHAR', 'STRING', 'CHAR') + export_fields = ( + 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', + 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', + 'filterable', 'expression', 'description', 'python_date_format', + 'database_expression' + ) + + def __repr__(self): + return self.column_name + + @property + def isnum(self): + return any([t in self.type.upper() for t in self.num_types]) + + @property + def is_time(self): + return any([t in self.type.upper() for t in self.date_types]) + + @property + def is_string(self): + return any([t in self.type.upper() for t in self.str_types]) + + @property + def sqla_col(self): + name = self.column_name + if not self.expression: + col = column(self.column_name).label(name) + else: + col = literal_column(self.expression).label(name) + return col + + def get_time_filter(self, start_dttm, end_dttm): + col = self.sqla_col.label('__time') + return and_( + col >= text(self.dttm_sql_literal(start_dttm)), + col <= text(self.dttm_sql_literal(end_dttm)), + ) + + def get_timestamp_expression(self, time_grain): + """Getting the time component of the query""" + expr = self.expression or self.column_name + if not self.expression and not time_grain: + return column(expr, type_=DateTime).label(DTTM_ALIAS) + if time_grain: + pdf = self.python_date_format + if pdf in ('epoch_s', 'epoch_ms'): + # if epoch, translate to DATE using db specific conf + db_spec = self.table.database.db_engine_spec + if pdf == 'epoch_s': + expr = db_spec.epoch_to_dttm().format(col=expr) + elif pdf == 'epoch_ms': + expr = db_spec.epoch_ms_to_dttm().format(col=expr) + grain = self.table.database.grains_dict().get(time_grain, '{col}') + expr = grain.function.format(col=expr) + return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) + + @classmethod + def import_obj(cls, i_column): + def lookup_obj(lookup_column): + return db.session.query(TableColumn).filter( + TableColumn.table_id == lookup_column.table_id, + TableColumn.column_name == lookup_column.column_name).first() + return import_util.import_simple_obj(db.session, i_column, lookup_obj) + + def dttm_sql_literal(self, dttm): + """Convert datetime object to a SQL expression string + + If database_expression is empty, the internal dttm + will be parsed as the string with the pattern that + the user inputted (python_date_format) + If database_expression is not empty, the internal dttm + will be parsed as the sql sentence for the database to convert + """ + + tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' + if self.database_expression: + return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) + elif tf == 'epoch_s': + return str((dttm - datetime(1970, 1, 1)).total_seconds()) + elif tf == 'epoch_ms': + return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) + else: + s = 
self.table.database.db_engine_spec.convert_dttm( + self.type, dttm) + return s or "'{}'".format(dttm.strftime(tf)) + + +class SqlMetric(Model, AuditMixinNullable, ImportMixin): + + """ORM object for metrics, each table can have multiple metrics""" + + __tablename__ = 'sql_metrics' + id = Column(Integer, primary_key=True) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', + backref=backref('metrics', cascade='all, delete-orphan'), + foreign_keys=[table_id]) + expression = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) + + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', + 'description', 'is_restricted', 'd3format') + + @property + def sqla_col(self): + name = self.metric_name + return literal_column(self.expression).label(name) + + @property + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.table.full_name) if self.table else None + + @classmethod + def import_obj(cls, i_metric): + def lookup_obj(lookup_metric): + return db.session.query(SqlMetric).filter( + SqlMetric.table_id == lookup_metric.table_id, + SqlMetric.metric_name == lookup_metric.metric_name).first() + return import_util.import_simple_obj(db.session, i_metric, lookup_obj) + + class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): """An ORM object for SqlAlchemy table references""" @@ -927,6 +1093,8 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): perm = Column(String(1000)) baselink = "tablemodelview" + column_cls = TableColumn + metric_cls = SqlMetric export_fields = ( 'table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', @@ -1365,140 +1533,108 @@ def fetch_metadata(self): self.main_dttm_col = any_date_col @classmethod - def import_obj(cls, datasource_to_import, import_time=None): + def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. 
""" - session = db.session - make_transient(datasource_to_import) - logging.info('Started import of the datasource: {}' - .format(datasource_to_import.to_json())) - - datasource_to_import.id = None - database_name = datasource_to_import.params_dict['database_name'] - datasource_to_import.database_id = session.query(Database).filter_by( - database_name=database_name).one().id - datasource_to_import.alter_params(import_time=import_time) - - # override the datasource - datasource = ( - session.query(SqlaTable).join(Database) - .filter( - SqlaTable.table_name == datasource_to_import.table_name, - SqlaTable.schema == datasource_to_import.schema, - Database.id == datasource_to_import.database_id, - ) - .first() - ) - - if datasource: - datasource.override(datasource_to_import) - session.flush() - else: - datasource = datasource_to_import.copy() - session.add(datasource) - session.flush() + def lookup_sqlatable(table): + return db.session.query(SqlaTable).join(Database).filter( + SqlaTable.table_name == table.table_name, + SqlaTable.schema == table.schema, + Database.id == table.database_id, + ).first() - for m in datasource_to_import.metrics: - new_m = m.copy() - new_m.table_id = datasource.id - logging.info('Importing metric {} from the datasource: {}'.format( - new_m.to_json(), datasource_to_import.full_name)) - imported_m = SqlMetric.import_obj(new_m) - if imported_m not in datasource.metrics: - datasource.metrics.append(imported_m) - - for c in datasource_to_import.columns: - new_c = c.copy() - new_c.table_id = datasource.id - logging.info('Importing column {} from the datasource: {}'.format( - new_c.to_json(), datasource_to_import.full_name)) - imported_c = TableColumn.import_obj(new_c) - if imported_c not in datasource.columns: - datasource.columns.append(imported_c) - db.session.flush() - - return datasource.id + def lookup_database(table): + return db.session.query(Database).filter_by( + database_name=table.params_dict['database_name']).one() + return import_util.import_datasource( + db.session, i_datasource, lookup_database, lookup_sqlatable, + import_time) sqla.event.listen(SqlaTable, 'after_insert', set_perm) sqla.event.listen(SqlaTable, 'after_update', set_perm) -class SqlMetric(Model, AuditMixinNullable, ImportMixin): +class DruidCluster(Model, AuditMixinNullable): - """ORM object for metrics, each table can have multiple metrics""" + """ORM object referencing the Druid clusters""" - __tablename__ = 'sql_metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', - backref=backref('metrics', cascade='all, delete-orphan'), - foreign_keys=[table_id]) - expression = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) + __tablename__ = 'clusters' + type = "druid" - export_fields = ( - 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', - 'description', 'is_restricted', 'd3format') + id = Column(Integer, primary_key=True) + cluster_name = Column(String(250), unique=True) + coordinator_host = Column(String(255)) + coordinator_port = Column(Integer) + coordinator_endpoint = Column( + String(255), default='druid/coordinator/v1/metadata') + broker_host = Column(String(255)) + broker_port = Column(Integer) + broker_endpoint = Column(String(255), default='druid/v2') + metadata_last_refreshed = 
Column(DateTime) + cache_timeout = Column(Integer) - @property - def sqla_col(self): - name = self.metric_name - return literal_column(self.expression).label(name) + def __repr__(self): + return self.cluster_name - @property - def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.table.full_name) if self.table else None + def get_pydruid_client(self): + cli = PyDruid( + "http://{0}:{1}/".format(self.broker_host, self.broker_port), + self.broker_endpoint) + return cli - @classmethod - def import_obj(cls, metric_to_import): - session = db.session - make_transient(metric_to_import) - metric_to_import.id = None - - # find if the column was already imported - existing_metric = session.query(SqlMetric).filter( - SqlMetric.table_id == metric_to_import.table_id, - SqlMetric.metric_name == metric_to_import.metric_name).first() - metric_to_import.table = None - if existing_metric: - existing_metric.override(metric_to_import) - session.flush() - return existing_metric + def get_datasources(self): + endpoint = ( + "http://{obj.coordinator_host}:{obj.coordinator_port}/" + "{obj.coordinator_endpoint}/datasources" + ).format(obj=self) - session.add(metric_to_import) - session.flush() - return metric_to_import + return json.loads(requests.get(endpoint).text) + def get_druid_version(self): + endpoint = ( + "http://{obj.coordinator_host}:{obj.coordinator_port}/status" + ).format(obj=self) + return json.loads(requests.get(endpoint).text)['version'] -class TableColumn(Model, AuditMixinNullable, ImportMixin): + def refresh_datasources(self, datasource_name=None, merge_flag=False): + """Refresh metadata of all datasources in the cluster + If ``datasource_name`` is specified, only that datasource is updated + """ + self.druid_version = self.get_druid_version() + for datasource in self.get_datasources(): + if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): + if not datasource_name or datasource_name == datasource: + DruidDatasource.sync_to_db(datasource, self, merge_flag) - """ORM object for table columns, each table can have multiple columns""" + @property + def perm(self): + return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) - __tablename__ = 'table_columns' + @property + def name(self): + return self.cluster_name + + +class DruidColumn(Model, AuditMixinNullable, ImportMixin): + """ORM model for storing Druid datasource column metadata""" + + __tablename__ = 'columns' id = Column(Integer, primary_key=True) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. 
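+    # Columns are keyed by datasource_name (no surrogate foreign key) and are
+    # deleted together with their datasource via the 'all, delete-orphan'
+    # cascade on the backref below.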
+ datasource = relationship( + 'DruidDatasource', backref=backref('columns', cascade='all, delete-orphan'), - foreign_keys=[table_id]) + enable_typechecks=False) column_name = Column(String(255)) - verbose_name = Column(String(1024)) - is_dttm = Column(Boolean, default=False) is_active = Column(Boolean, default=True) - type = Column(String(32), default='') + type = Column(String(32)) groupby = Column(Boolean, default=False) count_distinct = Column(Boolean, default=False) sum = Column(Boolean, default=False) @@ -1506,19 +1642,13 @@ class TableColumn(Model, AuditMixinNullable, ImportMixin): max = Column(Boolean, default=False) min = Column(Boolean, default=False) filterable = Column(Boolean, default=False) - expression = Column(Text, default='') - description = Column(Text, default='') - python_date_format = Column(String(255)) - database_expression = Column(String(255)) + description = Column(Text) + dimension_spec_json = Column(Text) - num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') - date_types = ('DATE', 'TIME') - str_types = ('VARCHAR', 'STRING', 'CHAR') export_fields = ( - 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', - 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', - 'filterable', 'expression', 'description', 'python_date_format', - 'database_expression' + 'datasource_name', 'column_name', 'is_active', 'type', 'groupby', + 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', + 'description', 'dimension_spec_json' ) def __repr__(self): @@ -1526,135 +1656,142 @@ def __repr__(self): @property def isnum(self): - return any([t in self.type.upper() for t in self.num_types]) - - @property - def is_time(self): - return any([t in self.type.upper() for t in self.date_types]) + return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') @property - def is_string(self): - return any([t in self.type.upper() for t in self.str_types]) + def dimension_spec(self): + if self.dimension_spec_json: + return json.loads(self.dimension_spec_json) - @property - def sqla_col(self): - name = self.column_name - if not self.expression: - col = column(self.column_name).label(name) + def generate_metrics(self): + """Generate metrics based on the column metadata""" + M = DruidMetric # noqa + metrics = [] + metrics.append(DruidMetric( + metric_name='count', + verbose_name='COUNT(*)', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'count'}) + )) + # Somehow we need to reassign this for UDAFs + if self.type in ('DOUBLE', 'FLOAT'): + corrected_type = 'DOUBLE' else: - col = literal_column(self.expression).label(name) - return col - - def get_time_filter(self, start_dttm, end_dttm): - col = self.sqla_col.label('__time') - return and_( - col >= text(self.dttm_sql_literal(start_dttm)), - col <= text(self.dttm_sql_literal(end_dttm)), - ) - - def get_timestamp_expression(self, time_grain): - """Getting the time component of the query""" - expr = self.expression or self.column_name - if not self.expression and not time_grain: - return column(expr, type_=DateTime).label(DTTM_ALIAS) - if time_grain: - pdf = self.python_date_format - if pdf in ('epoch_s', 'epoch_ms'): - # if epoch, translate to DATE using db specific conf - db_spec = self.table.database.db_engine_spec - if pdf == 'epoch_s': - expr = db_spec.epoch_to_dttm().format(col=expr) - elif pdf == 'epoch_ms': - expr = db_spec.epoch_ms_to_dttm().format(col=expr) - grain = self.table.database.grains_dict().get(time_grain, '{col}') - expr = grain.function.format(col=expr) - return literal_column(expr, 
type_=DateTime).label(DTTM_ALIAS) - - @classmethod - def import_obj(cls, column_to_import): - session = db.session - make_transient(column_to_import) - column_to_import.id = None - column_to_import.table = None - - # find if the column was already imported - existing_column = session.query(TableColumn).filter( - TableColumn.table_id == column_to_import.table_id, - TableColumn.column_name == column_to_import.column_name).first() - column_to_import.table = None - if existing_column: - existing_column.override(column_to_import) - session.flush() - return existing_column + corrected_type = self.type - session.add(column_to_import) - session.flush() - return column_to_import + if self.sum and self.isnum: + mt = corrected_type.lower() + 'Sum' + name = 'sum__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='sum', + verbose_name='SUM({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) - def dttm_sql_literal(self, dttm): - """Convert datetime object to a SQL expression string + if self.avg and self.isnum: + mt = corrected_type.lower() + 'Avg' + name = 'avg__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='avg', + verbose_name='AVG({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) - If database_expression is empty, the internal dttm - will be parsed as the string with the pattern that - the user inputted (python_date_format) - If database_expression is not empty, the internal dttm - will be parsed as the sql sentence for the database to convert - """ + if self.min and self.isnum: + mt = corrected_type.lower() + 'Min' + name = 'min__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='min', + verbose_name='MIN({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.max and self.isnum: + mt = corrected_type.lower() + 'Max' + name = 'max__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='max', + verbose_name='MAX({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.count_distinct: + name = 'count_distinct__' + self.column_name + if self.type == 'hyperUnique' or self.type == 'thetaSketch': + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type=self.type, + json=json.dumps({ + 'type': self.type, + 'name': name, + 'fieldName': self.column_name + }) + )) + else: + mt = 'count_distinct' + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type='count_distinct', + json=json.dumps({ + 'type': 'cardinality', + 'name': name, + 'fieldNames': [self.column_name]}) + )) + session = get_session() + new_metrics = [] + for metric in metrics: + m = ( + session.query(M) + .filter(M.metric_name == metric.metric_name) + .filter(M.datasource_name == self.datasource_name) + .filter(DruidCluster.cluster_name == self.datasource.cluster_name) + .first() + ) + metric.datasource_name = self.datasource_name + if not m: + new_metrics.append(metric) + session.add(metric) + session.flush() - tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' - if self.database_expression: - return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) - elif tf == 'epoch_s': - return 
str((dttm - datetime(1970, 1, 1)).total_seconds()) - elif tf == 'epoch_ms': - return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) - else: - s = self.table.database.db_engine_spec.convert_dttm( - self.type, dttm) - return s or "'{}'".format(dttm.strftime(tf)) + @classmethod + def import_obj(cls, i_column): + def lookup_obj(lookup_column): + return db.session.query(DruidColumn).filter( + DruidColumn.datasource_name == lookup_column.datasource_name, + DruidColumn.column_name == lookup_column.column_name).first() + return import_util.import_simple_obj(db.session, i_column, lookup_obj) -class DruidCluster(Model, AuditMixinNullable): - """ORM object referencing the Druid clusters""" +class DruidMetric(Model, AuditMixinNullable, ImportMixin): - __tablename__ = 'clusters' - type = "druid" + """ORM object referencing Druid metrics for a datasource""" + __tablename__ = 'metrics' id = Column(Integer, primary_key=True) - cluster_name = Column(String(250), unique=True) - coordinator_host = Column(String(255)) - coordinator_port = Column(Integer) - coordinator_endpoint = Column( - String(255), default='druid/coordinator/v1/metadata') - broker_host = Column(String(255)) - broker_port = Column(Integer) - broker_endpoint = Column(String(255), default='druid/v2') - metadata_last_refreshed = Column(DateTime) - cache_timeout = Column(Integer) - - def __repr__(self): - return self.cluster_name - - def get_pydruid_client(self): - cli = PyDruid( - "http://{0}:{1}/".format(self.broker_host, self.broker_port), - self.broker_endpoint) - return cli - - def get_datasources(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/" - "{obj.coordinator_endpoint}/datasources" - ).format(obj=self) - - return json.loads(requests.get(endpoint).text) - - def get_druid_version(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/status" - ).format(obj=self) - return json.loads(requests.get(endpoint).text)['version'] + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. 
+ datasource = relationship( + 'DruidDatasource', + backref=backref('metrics', cascade='all, delete-orphan'), + enable_typechecks=False) + json = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) def refresh_datasources(self, datasource_name=None, merge_flag=False): """Refresh metadata of all datasources in the cluster @@ -1666,17 +1803,37 @@ def refresh_datasources(self, datasource_name=None, merge_flag=False): if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): if not datasource_name or datasource_name == datasource: DruidDatasource.sync_to_db(datasource, self, merge_flag) + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'datasource_name', + 'json', 'description', 'is_restricted', 'd3format' + ) @property - def perm(self): - return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) + def json_obj(self): + try: + obj = json.loads(self.json) + except Exception: + obj = {} + return obj @property - def name(self): - return self.cluster_name + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.datasource.full_name + ) if self.datasource else None + + @classmethod + def import_obj(cls, i_metric): + def lookup_obj(lookup_metric): + return db.session.query(DruidMetric).filter( + DruidMetric.datasource_name == lookup_metric.datasource_name, + DruidMetric.metric_name == lookup_metric.metric_name).first() + return import_util.import_simple_obj(db.session, i_metric, lookup_obj) -class DruidDatasource(Model, AuditMixinNullable, Queryable): +class DruidDatasource(Model, AuditMixinNullable, Queryable, ImportMixin): """ORM object referencing Druid datasources (tables)""" @@ -1703,8 +1860,17 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) offset = Column(Integer, default=0) cache_timeout = Column(Integer) + params = Column(String(1000)) perm = Column(String(1000)) + metric_cls = DruidMetric + column_cls = DruidColumn + + export_fields = ( + 'datasource_name', 'is_hidden', 'description', 'default_endpoint', + 'cluster_name', 'is_featured', 'offset', 'cache_timeout', 'params' + ) + @property def database(self): return self.cluster @@ -1782,6 +1948,27 @@ def get_metric_obj(self, metric_name): if m.metric_name == metric_name ][0] + @classmethod + def import_obj(cls, i_datasource, import_time=None): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. 
+ """ + def lookup_datasource(d): + return db.session.query(DruidDatasource).join(DruidCluster).filter( + DruidDatasource.datasource_name == d.datasource_name, + DruidCluster.cluster_name == d.cluster_name, + ).first() + + def lookup_cluster(d): + return db.session.query(DruidCluster).filter_by( + cluster_name=d.cluster_name).one() + return import_util.import_datasource( + db.session, i_datasource, lookup_cluster, lookup_datasource, + import_time) + @staticmethod def version_higher(v1, v2): """is v1 higher than v2 @@ -2415,183 +2602,6 @@ def wrapper(*args, **kwargs): return wrapper -class DruidMetric(Model, AuditMixinNullable): - - """ORM object referencing Druid metrics for a datasource""" - - __tablename__ = 'metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - 'DruidDatasource', - backref=backref('metrics', cascade='all, delete-orphan'), - enable_typechecks=False) - json = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) - - @property - def json_obj(self): - try: - obj = json.loads(self.json) - except Exception: - obj = {} - return obj - - @property - def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.datasource.full_name - ) if self.datasource else None - - -class DruidColumn(Model, AuditMixinNullable): - - """ORM model for storing Druid datasource column metadata""" - - __tablename__ = 'columns' - id = Column(Integer, primary_key=True) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. 
- datasource = relationship( - 'DruidDatasource', - backref=backref('columns', cascade='all, delete-orphan'), - enable_typechecks=False) - column_name = Column(String(255)) - is_active = Column(Boolean, default=True) - type = Column(String(32)) - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) - description = Column(Text) - dimension_spec_json = Column(Text) - - def __repr__(self): - return self.column_name - - @property - def isnum(self): - return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') - - @property - def dimension_spec(self): - if self.dimension_spec_json: - return json.loads(self.dimension_spec_json) - - def generate_metrics(self): - """Generate metrics based on the column metadata""" - M = DruidMetric # noqa - metrics = [] - metrics.append(DruidMetric( - metric_name='count', - verbose_name='COUNT(*)', - metric_type='count', - json=json.dumps({'type': 'count', 'name': 'count'}) - )) - # Somehow we need to reassign this for UDAFs - if self.type in ('DOUBLE', 'FLOAT'): - corrected_type = 'DOUBLE' - else: - corrected_type = self.type - - if self.sum and self.isnum: - mt = corrected_type.lower() + 'Sum' - name = 'sum__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='sum', - verbose_name='SUM({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - - if self.avg and self.isnum: - mt = corrected_type.lower() + 'Avg' - name = 'avg__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='avg', - verbose_name='AVG({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - - if self.min and self.isnum: - mt = corrected_type.lower() + 'Min' - name = 'min__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='min', - verbose_name='MIN({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.max and self.isnum: - mt = corrected_type.lower() + 'Max' - name = 'max__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='max', - verbose_name='MAX({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.count_distinct: - name = 'count_distinct__' + self.column_name - if self.type == 'hyperUnique' or self.type == 'thetaSketch': - metrics.append(DruidMetric( - metric_name=name, - verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type=self.type, - json=json.dumps({ - 'type': self.type, - 'name': name, - 'fieldName': self.column_name - }) - )) - else: - mt = 'count_distinct' - metrics.append(DruidMetric( - metric_name=name, - verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type='count_distinct', - json=json.dumps({ - 'type': 'cardinality', - 'name': name, - 'fieldNames': [self.column_name]}) - )) - session = get_session() - new_metrics = [] - for metric in metrics: - m = ( - session.query(M) - .filter(M.metric_name == metric.metric_name) - .filter(M.datasource_name == self.datasource_name) - .filter(DruidCluster.cluster_name == self.datasource.cluster_name) - .first() - ) - metric.datasource_name = self.datasource_name - if not m: - 
new_metrics.append(metric) - session.add(metric) - session.flush() - - class FavStar(Model): __tablename__ = 'favstar' diff --git a/superset/source_registry.py b/superset/source_registry.py index 2c72157ebf0b..012d15105347 100644 --- a/superset/source_registry.py +++ b/superset/source_registry.py @@ -36,7 +36,10 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name, schema, database_name): datasource_class = SourceRegistry.sources[datasource_type] datasources = session.query(datasource_class).all() - db_ds = [d for d in datasources if d.database.name == database_name and + + # Filter datasoures that don't have database. + db_ds = [d for d in datasources if d.database and + d.database.name == database_name and d.name == datasource_name and schema == schema] return db_ds[0] @@ -65,16 +68,12 @@ def query_datasources_by_name( def get_eager_datasource(cls, session, datasource_type, datasource_id): """Returns datasource with columns and metrics.""" datasource_class = SourceRegistry.sources[datasource_type] - if datasource_type == 'table': - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics) - ) - .filter_by(id=datasource_id) - .one() + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics) ) - # TODO: support druid datasources. - return session.query(datasource_class).filter_by( - id=datasource_id).first() + .filter_by(id=datasource_id) + .one() + ) diff --git a/superset/views.py b/superset/views.py index 6cb206bfff7f..f5ddf7ff77d2 100755 --- a/superset/views.py +++ b/superset/views.py @@ -1418,8 +1418,14 @@ def import_dashboards(self): if request.method == 'POST' and f: current_tt = int(time.time()) data = pickle.load(f) + # TODO: import DRUID datasources for table in data['datasources']: - models.SqlaTable.import_obj(table, import_time=current_tt) + if table.type == 'table': + models.SqlaTable.import_obj(table, import_time=current_tt) + else: + models.DruidDatasource.import_obj( + table, import_time=current_tt) + db.session.commit() for dashboard in data['dashboards']: models.Dashboard.import_obj( dashboard, import_time=current_tt) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 3201ce9d3478..6c7d79753729 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -34,6 +34,9 @@ def delete_imports(cls): for table in session.query(models.SqlaTable): if 'remote_id' in table.params_dict: session.delete(table) + for datasource in session.query(models.DruidDatasource): + if 'remote_id' in datasource.params_dict: + session.delete(datasource) session.commit() @classmethod @@ -52,6 +55,11 @@ def create_slice(self, name, ds_id=None, id=None, db_name='main', 'datasource_name': table_name, 'database_name': db_name, 'schema': '', + # Test for trailing commas + "metrics": [ + "sum__signup_attempt_email", + "sum__signup_attempt_facebook", + ], } if table_name and not ds_id: @@ -79,7 +87,8 @@ def create_dashboard(self, title, id=0, slcs=[]): json_metadata=json.dumps(json_metadata) ) - def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): + def create_table( + self, name, schema='', id=0, cols_names=[], metric_names=[]): params = {'remote_id': id, 'database_name': 'main'} table = models.SqlaTable( id=id, @@ -94,6 +103,23 @@ def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): 
table.metrics.append(models.SqlMetric(metric_name=metric_name)) return table + def create_druid_datasource( + self, name, id=0, cols_names=[], metric_names=[]): + params = {'remote_id': id, 'database_name': 'druid_test'} + datasource = models.DruidDatasource( + id=id, + datasource_name=name, + cluster_name='druid_test', + params=json.dumps(params) + ) + for col_name in cols_names: + datasource.columns.append( + models.DruidColumn(column_name=col_name)) + for metric_name in metric_names: + datasource.metrics.append(models.DruidMetric( + metric_name=metric_name)) + return datasource + def get_slice(self, slc_id): return db.session.query(models.Slice).filter_by(id=slc_id).first() @@ -113,6 +139,10 @@ def get_table(self, table_id): return db.session.query(models.SqlaTable).filter_by( id=table_id).first() + def get_datasource(self, datasource_id): + return db.session.query(models.DruidDatasource).filter_by( + id=datasource_id).first() + def get_table_by_name(self, name): return db.session.query(models.SqlaTable).filter_by( table_name=name).first() @@ -147,6 +177,19 @@ def assert_table_equals(self, expected_ds, actual_ds): set([m.metric_name for m in expected_ds.metrics]), set([m.metric_name for m in actual_ds.metrics])) + def assert_datasource_equals(self, expected_ds, actual_ds): + self.assertEquals( + expected_ds.datasource_name, actual_ds.datasource_name) + self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col) + self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics)) + self.assertEquals(len(expected_ds.columns), len(actual_ds.columns)) + self.assertEquals( + set([c.column_name for c in expected_ds.columns]), + set([c.column_name for c in actual_ds.columns])) + self.assertEquals( + set([m.metric_name for m in expected_ds.metrics]), + set([m.metric_name for m in actual_ds.metrics])) + def assert_slice_equals(self, expected_slc, actual_slc): self.assertEquals(expected_slc.slice_name, actual_slc.slice_name) self.assertEquals( @@ -353,63 +396,131 @@ def test_import_override_dashboard_2_slices(self): def test_import_table_no_metadata(self): table = self.create_table('pure_table', id=10001) - imported_t_id = models.SqlaTable.import_obj(table, import_time=1989) - imported_table = self.get_table(imported_t_id) - self.assert_table_equals(table, imported_table) + imported_id = models.SqlaTable.import_obj(table, import_time=1989) + imported = self.get_table(imported_id) + self.assert_table_equals(table, imported) def test_import_table_1_col_1_met(self): table = self.create_table( 'table_1_col_1_met', id=10002, cols_names=["col1"], metric_names=["metric1"]) - imported_t_id = models.SqlaTable.import_obj(table, import_time=1990) - imported_table = self.get_table(imported_t_id) - self.assert_table_equals(table, imported_table) + imported_id = models.SqlaTable.import_obj(table, import_time=1990) + imported = self.get_table(imported_id) + self.assert_table_equals(table, imported) self.assertEquals( {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'}, - json.loads(imported_table.params)) + json.loads(imported.params)) def test_import_table_2_col_2_met(self): table = self.create_table( 'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'], metric_names=['m1', 'm2']) - imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_id = models.SqlaTable.import_obj(table, import_time=1991) - imported_table = self.get_table(imported_t_id) - self.assert_table_equals(table, imported_table) + imported = self.get_table(imported_id) + 
self.assert_table_equals(table, imported) def test_import_table_override(self): table = self.create_table( 'table_override', id=10003, cols_names=['col1'], metric_names=['m1']) - imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_id = models.SqlaTable.import_obj(table, import_time=1991) table_over = self.create_table( 'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_table_over_id = models.SqlaTable.import_obj( + imported_over_id = models.SqlaTable.import_obj( table_over, import_time=1992) - imported_table_over = self.get_table(imported_table_over_id) - self.assertEquals(imported_t_id, imported_table_over.id) + imported_over = self.get_table(imported_over_id) + self.assertEquals(imported_id, imported_over.id) expected_table = self.create_table( 'table_override', id=10003, metric_names=['new_metric1', 'm1'], cols_names=['col1', 'new_col1', 'col2', 'col3']) - self.assert_table_equals(expected_table, imported_table_over) + self.assert_table_equals(expected_table, imported_over) def test_import_table_override_idential(self): table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_t_id = models.SqlaTable.import_obj(table, import_time=1993) + imported_id = models.SqlaTable.import_obj(table, import_time=1993) copy_table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_t_id_copy = models.SqlaTable.import_obj( + imported_id_copy = models.SqlaTable.import_obj( copy_table, import_time=1994) - self.assertEquals(imported_t_id, imported_t_id_copy) - self.assert_table_equals(copy_table, self.get_table(imported_t_id)) + self.assertEquals(imported_id, imported_id_copy) + self.assert_table_equals(copy_table, self.get_table(imported_id)) + + def test_import_druid_no_metadata(self): + datasource = self.create_druid_datasource('pure_druid', id=10001) + imported_id = models.DruidDatasource.import_obj( + datasource, import_time=1989) + imported = self.get_datasource(imported_id) + self.assert_datasource_equals(datasource, imported) + + def test_import_druid_1_col_1_met(self): + datasource = self.create_druid_datasource( + 'druid_1_col_1_met', id=10002, + cols_names=["col1"], metric_names=["metric1"]) + imported_id = models.DruidDatasource.import_obj( + datasource, import_time=1990) + imported = self.get_datasource(imported_id) + self.assert_datasource_equals(datasource, imported) + self.assertEquals( + {'remote_id': 10002, 'import_time': 1990, + 'database_name': 'druid_test'}, + json.loads(imported.params)) + + def test_import_druid_2_col_2_met(self): + datasource = self.create_druid_datasource( + 'druid_2_col_2_met', id=10003, cols_names=['c1', 'c2'], + metric_names=['m1', 'm2']) + imported_id = models.DruidDatasource.import_obj( + datasource, import_time=1991) + imported = self.get_datasource(imported_id) + self.assert_datasource_equals(datasource, imported) + + def test_import_druid_override(self): + datasource = self.create_druid_datasource( + 'druid_override', id=10003, cols_names=['col1'], + metric_names=['m1']) + imported_id = models.DruidDatasource.import_obj( + datasource, import_time=1991) + + table_over = self.create_druid_datasource( + 'druid_override', id=10003, + cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_over_id = models.DruidDatasource.import_obj( + table_over, import_time=1992) + + imported_over = 
self.get_datasource(imported_over_id) + self.assertEquals(imported_id, imported_over.id) + expected_datasource = self.create_druid_datasource( + 'druid_override', id=10003, metric_names=['new_metric1', 'm1'], + cols_names=['col1', 'new_col1', 'col2', 'col3']) + self.assert_datasource_equals(expected_datasource, imported_over) + + def test_import_druid_override_idential(self): + datasource = self.create_druid_datasource( + 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_id = models.DruidDatasource.import_obj( + datasource, import_time=1993) + + copy_datasource = self.create_druid_datasource( + 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_id_copy = models.DruidDatasource.import_obj( + copy_datasource, import_time=1994) + + self.assertEquals(imported_id, imported_id_copy) + self.assert_datasource_equals( + copy_datasource, self.get_datasource(imported_id)) + if __name__ == '__main__': unittest.main()
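
Note: the params_dict change in models.py above is what allows slice/datasource
params to carry trailing commas (exercised by the "metrics" fixture in the new
tests). A minimal standalone sketch of that behaviour, assuming Python 3; the
helper name below is illustrative and not part of the patch:

    import json
    import re

    def loads_allowing_trailing_commas(params):
        # Mirrors ImportMixin.params_dict: drop a comma that is followed only by
        # whitespace before a closing brace/bracket, then parse as JSON. Note
        # that the regexes require at least one whitespace character after the
        # comma, so a bare ',}' with no whitespace is left untouched.
        params = re.sub(",[ \t\r\n]+}", "}", params)
        params = re.sub(",[ \t\r\n]+\]", "]", params)
        return json.loads(params)

    loads_allowing_trailing_commas('{"metrics": ["m1", "m2",\n ],\n }')
    # -> {'metrics': ['m1', 'm2']}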