From 0d05dd15be885d17be189ef3946931bd42867e28 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 14 Nov 2016 15:46:05 -0800 Subject: [PATCH] progrss --- superset/data/__init__.py | 3 - .../e46f2d27a08e_materialize_perms.py | 27 +++ superset/models.py | 27 ++- superset/security.py | 33 ++- superset/views.py | 25 ++- tests/base_tests.py | 56 +++-- tests/celery_tests.py | 6 +- tests/core_tests.py | 19 +- tests/sqllab_tests.py | 199 ++++++++++-------- 9 files changed, 258 insertions(+), 137 deletions(-) create mode 100644 superset/migrations/versions/e46f2d27a08e_materialize_perms.py diff --git a/superset/data/__init__.py b/superset/data/__init__.py index f74bb652f3e6..f88895b2f31f 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -891,11 +891,8 @@ def load_unicode_test_data(): dash.position_json = json.dumps([pos], indent=4) dash.slug = "unicode-test" dash.slices = [slc] - print('merge') db.session.merge(dash) - print('commit') db.session.commit() - print('after') def load_random_time_series_data(): diff --git a/superset/migrations/versions/e46f2d27a08e_materialize_perms.py b/superset/migrations/versions/e46f2d27a08e_materialize_perms.py new file mode 100644 index 000000000000..7611671fe16b --- /dev/null +++ b/superset/migrations/versions/e46f2d27a08e_materialize_perms.py @@ -0,0 +1,27 @@ +"""materialize perms + +Revision ID: e46f2d27a08e +Revises: c611f2b591b8 +Create Date: 2016-11-14 15:23:32.594898 + +""" + +# revision identifiers, used by Alembic. +revision = 'e46f2d27a08e' +down_revision = 'c611f2b591b8' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('datasources', sa.Column('perm', sa.String(length=1000), nullable=True)) + op.add_column('dbs', sa.Column('perm', sa.String(length=1000), nullable=True)) + op.add_column('tables', sa.Column('perm', sa.String(length=1000), nullable=True)) + + +def downgrade(): + op.drop_column('tables', 'perm') + op.drop_column('datasources', 'perm') + op.drop_column('dbs', 'perm') + diff --git a/superset/models.py b/superset/models.py index ebe63cc70ae6..ed9e75d6a8cd 100644 --- a/superset/models.py +++ b/superset/models.py @@ -70,6 +70,10 @@ FillterPattern = re.compile(r'''((?:[^,"']|"[^"]*"|'[^']*')+)''') +def set_perm(mapper, connection, target): # noqa + target.perm = target.get_perm() + + def init_metrics_perm(metrics=None): """Create permissions for restricted metrics @@ -216,7 +220,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin): params = Column(Text) description = Column(Text) cache_timeout = Column(Integer) - perm = Column(String(2000), unique=True) + perm = Column(String(1000)) owners = relationship("User", secondary=slice_user) export_fields = ('slice_name', 'datasource_type', 'datasource_name', @@ -383,14 +387,14 @@ def import_obj(cls, slc_to_import, import_time=None): return slc_to_import.id -def set_perm(mapper, connection, target): # noqa +def set_related_perm(mapper, connection, target): # noqa src_class = target.cls_model id_ = target.datasource_id ds = db.session.query(src_class).filter_by(id=int(id_)).first() target.perm = ds.perm -sqla.event.listen(Slice, 'before_insert', set_perm) -sqla.event.listen(Slice, 'before_update', set_perm) +sqla.event.listen(Slice, 'before_insert', set_related_perm) +sqla.event.listen(Slice, 'before_update', set_related_perm) dashboard_slices = Table( @@ -681,7 +685,7 @@ class Database(Model, AuditMixinNullable): "engine_params": {} } """)) - perm = Column(String(2000), unique=True) + perm = Column(String(1000)) def __repr__(self): return self.database_name @@ -849,6 +853,9 @@ def get_perm(self): return ( "[{obj.database_name}].(id:{obj.id})").format(obj=self) +sqla.event.listen(Database, 'before_insert', set_perm) +sqla.event.listen(Database, 'before_update', set_perm) + class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): @@ -875,7 +882,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): schema = Column(String(255)) sql = Column(Text) params = Column(Text) - perm = Column(String(2000), unique=True) + perm = Column(String(1000)) baselink = "tablemodelview" export_fields = ( @@ -1317,6 +1324,9 @@ def import_obj(cls, datasource_to_import, import_time=None): return datasource.id +sqla.event.listen(SqlaTable, 'before_insert', set_perm) +sqla.event.listen(SqlaTable, 'before_update', set_perm) + class SqlMetric(Model, AuditMixinNullable, ImportMixin): @@ -1592,7 +1602,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) offset = Column(Integer, default=0) cache_timeout = Column(Integer) - perm = Column(String(2000), unique=True) + perm = Column(String(1000)) @property def database(self): @@ -2196,6 +2206,9 @@ def get_having_filters(self, raw_filters): filters = cond return filters +sqla.event.listen(DruidDatasource, 'before_insert', set_perm) +sqla.event.listen(DruidDatasource, 'before_update', set_perm) + class Log(Model): diff --git a/superset/security.py b/superset/security.py index e3b7fa033cbd..52fc67ea3b45 100644 --- a/superset/security.py +++ b/superset/security.py @@ -19,6 +19,7 @@ 'AccessRequestsModelView', 'Manage', 'SQL Lab', + 'Queries', 'Refresh Druid Metadata', 'ResetPasswordView', 'RoleModelView', @@ -76,6 +77,7 @@ def get_or_create_main_db(): def sync_role_definitions(): """Inits the Superset application with security roles and such""" + logging.info("Syncing role definition") # Creating default roles alpha = sm.add_role("Alpha") @@ -83,6 +85,7 @@ def sync_role_definitions(): gamma = sm.add_role("Gamma") public = sm.add_role("Public") sql_lab = sm.add_role("sql_lab") + granter = sm.add_role("granter") get_or_create_main_db() @@ -94,11 +97,11 @@ def sync_role_definitions(): perms = db.session.query(ab_models.PermissionView).all() perms = [p for p in perms if p.permission and p.view_menu] - # set admin perms + logging.info("Syncing admin perms") for p in perms: sm.add_permission_role(admin, p) - # set alpha perms + logging.info("Syncing alpha perms") for p in perms: if ( ( @@ -111,7 +114,7 @@ def sync_role_definitions(): else: sm.del_permission_role(alpha, p) - # set gamma permissions and public to be alike if specified + logging.info("Syncing gamma perms and public if specified") PUBLIC_ROLE_LIKE_GAMMA = conf.get('PUBLIC_ROLE_LIKE_GAMMA', False) for p in perms: if ( @@ -129,7 +132,7 @@ def sync_role_definitions(): sm.del_permission_role(gamma, p) sm.del_permission_role(public, p) - # Managing the sql_lab role + logging.info("Syncing sql_lab perms") for p in perms: if ( p.view_menu.name in {'SQL Lab'} or @@ -140,20 +143,30 @@ def sync_role_definitions(): else: sm.del_permission_role(sql_lab, p) - # Making sure all data source perms have been created + logging.info("Syncing granter perms") + for p in perms: + if ( + p.permission.name in { + 'can_override_role_permissions', 'can_aprove'} + ): + sm.add_permission_role(granter, p) + else: + sm.del_permission_role(granter, p) + + logging.info("Making sure all data source perms have been created") session = db.session() datasources = [ - table.perm for table in session.query(models.SqlaTable).all()] + o for o in session.query(models.SqlaTable).all()] datasources += [ - table.perm for table in session.query(models.DruidDatasource).all()] + o for o in session.query(models.DruidDatasource).all()] for datasource in datasources: perm = datasource.get_perm() sm.add_permission_view_menu('datasource_access', perm) if perm != datasource.perm: datasource.perm = perm - # Making sure all database perms have been created - databases = [o.perm for o in session.query(models.Database).all()] + logging.info("Making sure all database perms have been created") + databases = [o for o in session.query(models.Database).all()] for database in databases: perm = database.get_perm() if perm != database.perm: @@ -161,5 +174,5 @@ def sync_role_definitions(): sm.add_permission_view_menu('database_access', perm) session.commit() - # Creating metric perms + logging.info("Making sure all metrics perms exist") models.init_metrics_perm() diff --git a/superset/views.py b/superset/views.py index e240f100726b..f6e65f893d76 100755 --- a/superset/views.py +++ b/superset/views.py @@ -240,11 +240,7 @@ def apply(self, query, func): # noqa self.has_perm('all_database_access', 'all_database_access')): return query perms = self.get_view_menus('database_access') - ids = [ - o.id for o in db.session.query(self.model).all() - if o.perm in perms - ] - return query.filter(self.model.id.in_(ids)) + return query.filter(self.model.perm.in_(perms)) class DatasourceFilter(SupersetFilter): @@ -518,10 +514,11 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa 'extra', 'database_name', 'sqlalchemy_uri', + 'perm', 'created_by', 'created_on', 'changed_by', - 'changed_on' + 'changed_on', ] add_template = "superset/models/database/add.html" edit_template = "superset/models/database/edit.html" @@ -630,6 +627,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa 'table_name', 'sql', 'is_featured', 'database', 'schema', 'description', 'owner', 'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout'] + show_columns = edit_columns + ['perm'] related_views = [TableColumnInlineView, SqlMetricInlineView] base_order = ('changed_on', 'desc') description_columns = { @@ -985,6 +983,19 @@ class LogModelView(SupersetModelView): icon="fa-list-ol") +class QueryView(SupersetModelView): + datamodel = SQLAInterface(models.Query) + list_columns = ['user', 'database', 'status', 'start_time', 'start_time'] + +appbuilder.add_view( + QueryView, + "Queries", + label=__("Queries"), + category="Manage", + category_label=__("Manage"), + icon="fa-search") + + class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.DruidDatasource) list_widget = ListWidgetWithCheckboxes @@ -998,6 +1009,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa 'is_featured', 'is_hidden', 'default_endpoint', 'offset', 'cache_timeout'] add_columns = edit_columns + show_columns = add_columns + ['perm'] page_size = 500 base_order = ('datasource_name', 'asc') description_columns = { @@ -2261,7 +2273,6 @@ def show_traceback(self): title=ascii_art.stacktrace, art=ascii_art.error), 500 - @has_access @expose("/welcome") def welcome(self): """Personalized welcome page""" diff --git a/tests/base_tests.py b/tests/base_tests.py index ab69ec42291c..954d8d69b012 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -4,6 +4,7 @@ from __future__ import print_function from __future__ import unicode_literals +import logging import json import os import unittest @@ -28,14 +29,22 @@ def __init__(self, *args, **kwargs): not os.environ.get('SOLO_TEST') and not os.environ.get('examples_loaded') ): + logging.info("Loading examples") cli.load_examples(load_test_data=True) - print("Syncing role definitions") + logging.info("Done loading examples") sync_role_definitions() os.environ['examples_loaded'] = '1' + else: + sync_role_definitions() super(SupersetTestCase, self).__init__(*args, **kwargs) self.client = app.test_client() self.maxDiff = None - sync_role_definitions() + + gamma_sqllab = sm.add_role("gamma_sqllab") + for perm in sm.find_role('Gamma').permissions: + sm.add_permission_role(gamma_sqllab, perm) + for perm in sm.find_role('sql_lab').permissions: + sm.add_permission_role(gamma_sqllab, perm) admin = appbuilder.sm.find_user('admin') if not admin: @@ -51,6 +60,13 @@ def __init__(self, *args, **kwargs): appbuilder.sm.find_role('Gamma'), password='general') + gamma_sqllab = appbuilder.sm.find_user('gamma_sqllab') + if not gamma_sqllab: + gamma_sqllab = appbuilder.sm.add_user( + 'gamma_sqllab', 'gamma_sqllab', 'user', 'gamma_sqllab@fab.org', + appbuilder.sm.find_role('gamma_sqllab'), + password='general') + alpha = appbuilder.sm.find_user('alpha') if not alpha: appbuilder.sm.add_user( @@ -79,7 +95,6 @@ def __init__(self, *args, **kwargs): session.add(druid_datasource2) session.commit() - sync_role_definitions() def get_or_create(self, cls, criteria, session): obj = session.query(cls).filter_by(**criteria).first() @@ -88,11 +103,10 @@ def get_or_create(self, cls, criteria, session): return obj def login(self, username='admin', password='general'): - resp = self.client.post( + resp = self.get_resp( '/login/', - data=dict(username=username, password=password), - follow_redirects=True) - assert 'Welcome' in resp.data.decode('utf-8') + data=dict(username=username, password=password)) + self.assertIn('Welcome', resp) def get_query_by_sql(self, sql): session = db.create_scoped_session() @@ -127,14 +141,19 @@ def get_druid_ds_by_name(self, name): return db.session.query(models.DruidDatasource).filter_by( datasource_name=name).first() - def get_resp(self, url): + def get_resp(self, url, data=None, follow_redirects=True): """Shortcut to get the parsed results while following redirects""" - resp = self.client.get(url, follow_redirects=True) - return resp.data.decode('utf-8') - - def get_json_resp(self, url): + if data: + resp = self.client.post( + url, data=data, follow_redirects=follow_redirects) + return resp.data.decode('utf-8') + else: + resp = self.client.get(url, follow_redirects=follow_redirects) + return resp.data.decode('utf-8') + + def get_json_resp(self, url, data=None): """Shortcut to get the parsed results while following redirects""" - resp = self.get_resp(url) + resp = self.get_resp(url, data=data) return json.loads(resp) def get_main_database(self, session): @@ -175,13 +194,14 @@ def revoke_public_access_to_table(self, table): perm.view_menu and table.perm in perm.view_menu.name): appbuilder.sm.del_permission_role(public_role, perm) - def run_sql(self, sql, user_name, client_id): - self.login(username=(user_name if user_name else 'admin')) + def run_sql(self, sql, client_id, user_name=None): + if user_name: + self.logout() + self.login(username=(user_name if user_name else 'admin')) dbid = self.get_main_database(db.session).id - resp = self.client.post( + resp = self.get_json_resp( '/superset/sql_json/', data=dict(database_id=dbid, sql=sql, select_as_create_as=False, client_id=client_id), ) - self.logout() - return json.loads(resp.data.decode('utf-8')) + return resp diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 6637237c83b6..af58afef6645 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -177,6 +177,7 @@ def test_add_limit_to_the_query(self): def test_run_sync_query(self): main_db = self.get_main_database(db.session) eng = main_db.get_sqla_engine() + perm_name = 'can_sql_json' db_id = main_db.id # Case 1. @@ -187,7 +188,8 @@ def test_run_sync_query(self): # Case 2. # Table and DB exists, CTA call to the backend. - sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'" + sql_where = ( + "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) result2 = self.run_sql( db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true') self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) @@ -198,7 +200,7 @@ def test_run_sync_query(self): # Check the data in the tmp table. df2 = pd.read_sql_query(sql=query2.select_sql, con=eng) data2 = df2.to_dict(orient='records') - self.assertEqual([{'name': 'can_sql'}], data2) + self.assertEqual([{'name': perm_name}], data2) # Case 3. # Table and DB exists, CTA call to the backend, no data. diff --git a/tests/core_tests.py b/tests/core_tests.py index 8c484e4af919..ac6df620e724 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -97,7 +97,9 @@ def assert_admin_view_menus_in(role_name, assert_func): def test_save_slice(self): self.login(username='admin') - slice_id = self.get_slice("Energy Sankey", db.session).id + slice_name = "Energy Sankey" + slice_id = self.get_slice(slice_name, db.session).id + db.session.commit() copy_name = "Test Sankey Save" tbl_id = self.table_ids.get('energy_usage') url = ( @@ -107,9 +109,15 @@ def test_save_slice(self): "collapsed_fieldsets=&action={}&datasource_name=energy_usage&" "datasource_id=1&datasource_type=table&previous_viz_type=sankey") - db.session.commit() + # Changing name resp = self.get_resp(url.format(tbl_id, slice_id, copy_name, 'save')) assert copy_name in resp + + # Setting the name back to its original name + resp = self.get_resp(url.format(tbl_id, slice_id, slice_name, 'save')) + assert slice_name in resp + + # Doing a basic overwrite assert 'Energy' in self.get_resp( url.format(tbl_id, slice_id, copy_name, 'overwrite')) @@ -260,15 +268,15 @@ def test_gamma(self): assert "List Dashboard" in self.get_resp('/dashboardmodelview/list/') def test_csv_endpoint(self): + self.login('admin') sql = """ SELECT first_name, last_name FROM ab_user WHERE first_name='admin' """ client_id = "{}".format(random.getrandbits(64))[:10] - self.run_sql(sql, 'admin', client_id) + self.run_sql(sql, client_id) - self.login('admin') resp = self.get_resp('/superset/csv/{}'.format(client_id)) data = csv.reader(io.StringIO(resp)) expected_data = csv.reader( @@ -373,8 +381,9 @@ def test_process_template(self): self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) def test_templated_sql_json(self): + self.login('admin') sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}' as test" - data = self.run_sql(sql, "admin", "fdaklj3ws") + data = self.run_sql(sql, "fdaklj3ws") self.assertEqual(data['data'][0]['test'], "2017-01-01T00:00:00") def test_table_metadata(self): diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 139a4ec139b4..4caad3c43255 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -4,9 +4,8 @@ from __future__ import print_function from __future__ import unicode_literals -import csv +from datetime import datetime, timedelta import json -import io import unittest from flask_appbuilder.security.sqla import models as ab_models @@ -20,20 +19,36 @@ class SqlLabTests(SupersetTestCase): def __init__(self, *args, **kwargs): super(SqlLabTests, self).__init__(*args, **kwargs) - def setUp(self): + def run_some_queries(self): + self.logout() db.session.query(models.Query).delete() - self.run_sql("SELECT * FROM ab_user", 'admin', client_id='client_id_1') - self.run_sql("SELECT * FROM NO_TABLE", 'admin', client_id='client_id_3') - self.run_sql("SELECT * FROM ab_permission", 'gamma', client_id='client_id_2') + db.session.commit() + self.run_sql( + "SELECT * FROM ab_user", + client_id='client_id_1', + user_name='admin') + self.run_sql( + "SELECT * FROM NO_TABLE", + client_id='client_id_3', + user_name='admin') + self.run_sql( + "SELECT * FROM ab_permission", + client_id='client_id_2', + user_name='gamma_sqllab') + self.logout() def tearDown(self): db.session.query(models.Query).delete() + db.session.commit() + self.logout() def test_sql_json(self): - data = self.run_sql('SELECT * FROM ab_user', 'admin', "1") + self.login('admin') + + data = self.run_sql('SELECT * FROM ab_user', "1") self.assertLess(0, len(data['data'])) - data = self.run_sql('SELECT * FROM unexistant_table', 'admin', "2") + data = self.run_sql('SELECT * FROM unexistant_table', "2") self.assertLess(0, len(data['error'])) def test_sql_json_has_access(self): @@ -48,119 +63,133 @@ def test_sql_json_has_access(self): ) astronaut = sm.add_role("Astronaut") sm.add_permission_role(astronaut, main_db_permission_view) - # Astronaut role is Gamma + main db permissions - for gamma_perm in sm.find_role('Gamma').permissions: - sm.add_permission_role(astronaut, gamma_perm) + # Astronaut role is Gamma + sqllab + main db permissions + for perm in sm.find_role('Gamma').permissions: + sm.add_permission_role(astronaut, perm) + for perm in sm.find_role('sql_lab').permissions: + sm.add_permission_role(astronaut, perm) gagarin = appbuilder.sm.find_user('gagarin') if not gagarin: appbuilder.sm.add_user( 'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr', - appbuilder.sm.find_role('Astronaut'), + astronaut, password='general') - data = self.run_sql('SELECT * FROM ab_user', 'gagarin', "3") + data = self.run_sql('SELECT * FROM ab_user', "3", user_name='gagarin') db.session.query(models.Query).delete() db.session.commit() self.assertLess(0, len(data['data'])) def test_queries_endpoint(self): - resp = self.client.get('/superset/queries/{}'.format(0)) + self.run_some_queries() + + # Not logged in, should error out + resp = self.client.get('/superset/queries/0') self.assertEquals(403, resp.status_code) + # Admin sees queries self.login('admin') - data = self.get_json_resp('/superset/queries/{}'.format(0)) + data = self.get_json_resp('/superset/queries/0') self.assertEquals(2, len(data)) - self.logout() - self.run_sql("SELECT * FROM ab_user1", 'admin', client_id='client_id_4') - self.run_sql("SELECT * FROM ab_user2", 'admin', client_id='client_id_5') + # Run 2 more queries + self.run_sql("SELECT * FROM ab_user1", client_id='client_id_4') + self.run_sql("SELECT * FROM ab_user2", client_id='client_id_5') self.login('admin') - data = self.get_json_resp('/superset/queries/{}'.format(0)) + data = self.get_json_resp('/superset/queries/0') self.assertEquals(4, len(data)) + now = datetime.now() + timedelta(days=1) query = db.session.query(models.Query).filter_by( sql='SELECT * FROM ab_user1').first() - query.changed_on = utils.EPOCH + query.changed_on = now db.session.commit() - data = self.get_json_resp('/superset/queries/{}'.format(123456000)) - self.assertEquals(3, len(data)) + data = self.get_json_resp( + '/superset/queries/{}'.format( + int(utils.datetime_to_epoch(now)))) + self.assertEquals(1, len(data)) self.logout() - resp = self.client.get('/superset/queries/{}'.format(0)) + resp = self.client.get('/superset/queries/0') self.assertEquals(403, resp.status_code) def test_search_query_on_db_id(self): - self.login('admin') - # Test search queries on database Id - resp = self.get_resp('/superset/search_queries?database_id=1') - data = json.loads(resp) - self.assertEquals(3, len(data)) - db_ids = [data[k]['dbId'] for k in data] - self.assertEquals([1, 1, 1], db_ids) - - resp = self.get_resp('/superset/search_queries?database_id=-1') - data = json.loads(resp) - self.assertEquals(0, len(data)) - self.logout() + self.run_some_queries() + self.login('admin') + # Test search queries on database Id + resp = self.get_resp('/superset/search_queries?database_id=1') + data = json.loads(resp) + self.assertEquals(3, len(data)) + db_ids = [data[k]['dbId'] for k in data] + self.assertEquals([1, 1, 1], db_ids) + + resp = self.get_resp('/superset/search_queries?database_id=-1') + data = json.loads(resp) + self.assertEquals(0, len(data)) def test_search_query_on_user(self): - self.login('admin') - # Test search queries on user Id - user = appbuilder.sm.find_user('admin') - resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) - data = json.loads(resp) - self.assertEquals(2, len(data)) - user_ids = [data[k]['userId'] for k in data] - self.assertEquals([user.id, user.id], user_ids) - - user = appbuilder.sm.find_user('gamma') - resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertEquals(list(data.values())[0]['userId'] , user.id) - self.logout() + self.run_some_queries() + self.login('admin') + + # Test search queries on user Id + user = appbuilder.sm.find_user('admin') + data = self.get_json_resp( + '/superset/search_queries?user_id={}'.format(user.id)) + self.assertEquals(2, len(data)) + user_ids = {data[k]['userId'] for k in data} + self.assertEquals(set([user.id]), user_ids) + + user = appbuilder.sm.find_user('gamma_sqllab') + resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertEquals(list(data.values())[0]['userId'] , user.id) def test_search_query_on_status(self): - self.login('admin') - # Test search queries on status - resp = self.get_resp('/superset/search_queries?status=success') - data = json.loads(resp) - self.assertEquals(2, len(data)) - states = [data[k]['state'] for k in data] - self.assertEquals(['success', 'success'], states) - - resp = self.get_resp('/superset/search_queries?status=failed') - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertEquals(list(data.values())[0]['state'], 'failed') - self.logout() + self.run_some_queries() + self.login('admin') + # Test search queries on status + resp = self.get_resp('/superset/search_queries?status=success') + data = json.loads(resp) + self.assertEquals(2, len(data)) + states = [data[k]['state'] for k in data] + self.assertEquals(['success', 'success'], states) + + resp = self.get_resp('/superset/search_queries?status=failed') + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertEquals(list(data.values())[0]['state'], 'failed') def test_search_query_on_text(self): - self.login('admin') - resp = self.get_resp('/superset/search_queries?search_text=permission') - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertIn('permission', list(data.values())[0]['sql']) - self.logout() + self.run_some_queries() + self.login('admin') + resp = self.get_resp('/superset/search_queries?search_text=permission') + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertIn('permission', list(data.values())[0]['sql']) def test_search_query_on_time(self): - self.login('admin') - first_query_time = db.session.query(models.Query).filter_by( - sql='SELECT * FROM ab_user').first().start_time - second_query_time = db.session.query(models.Query).filter_by( - sql='SELECT * FROM ab_permission').first().start_time - # Test search queries on time filter - from_time = 'from={}'.format(int(first_query_time)) - to_time = 'to={}'.format(int(second_query_time)) - params = [from_time, to_time] - resp = self.get_resp('/superset/search_queries?'+'&'.join(params)) - data = json.loads(resp) - self.assertEquals(2, len(data)) - for _, v in data.items(): - self.assertLess(int(first_query_time), v['startDttm']) - self.assertLess(v['startDttm'], int(second_query_time)) - self.logout() + self.run_some_queries() + self.login('admin') + first_query_time = ( + db.session.query(models.Query) + .filter_by(sql='SELECT * FROM ab_user').one() + ).start_time + second_query_time = ( + db.session.query(models.Query) + .filter_by(sql='SELECT * FROM ab_permission').one() + ).start_time + # Test search queries on time filter + from_time = 'from={}'.format(int(first_query_time)) + to_time = 'to={}'.format(int(second_query_time)) + params = [from_time, to_time] + resp = self.get_resp('/superset/search_queries?'+'&'.join(params)) + data = json.loads(resp) + self.assertEquals(2, len(data)) + for _, v in data.items(): + self.assertLess(int(first_query_time), v['startDttm']) + self.assertLess(v['startDttm'], int(second_query_time)) if __name__ == '__main__':