From bce02e3f518237c03273e3ed4d9d1a13d9f8f6a9 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 17 Nov 2016 11:58:33 -0800 Subject: [PATCH] [security] improving the security scheme (#1587) * [security] improving the security scheme * Addressing comments * improving docs * Creating security module to organize things * Moving CLI to its own module * perms * Materializung perms * progrss * Addressing comments, linting --- .gitignore | 1 + docs/security.rst | 21 +- superset/__init__.py | 2 + superset/assets/javascripts/SqlLab/actions.js | 2 +- .../SqlLab/components/DatabaseSelect.jsx | 20 +- superset/bin/superset | 151 +------------ superset/cli.py | 158 ++++++++++++++ superset/config.py | 1 - superset/data/__init__.py | 22 +- .../e46f2d27a08e_materialize_perms.py | 27 +++ superset/models.py | 53 ++++- superset/security.py | 178 +++++++++++++++ superset/utils.py | 131 ----------- superset/views.py | 206 +++++++++--------- tests/base_tests.py | 71 +++--- tests/celery_tests.py | 14 +- tests/core_tests.py | 53 +++-- tests/druid_tests.py | 4 +- tests/sqllab_tests.py | 201 +++++++++-------- 19 files changed, 769 insertions(+), 547 deletions(-) create mode 100755 superset/cli.py create mode 100644 superset/migrations/versions/e46f2d27a08e_materialize_perms.py create mode 100644 superset/security.py diff --git a/.gitignore b/.gitignore index 8e796ee55297..6c7253373c54 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ changelog.sh _build _static _images +_modules superset/bin/supersetc env_py3 .eggs diff --git a/docs/security.rst b/docs/security.rst index d9041ea7d0f4..55e079dbe7d4 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -7,8 +7,19 @@ FAB provides authentication, user management, permissions and roles. Provided Roles -------------- -Superset ships with 3 roles that are handled by Superset itself. You can -assume that these 3 roles will stay up-to-date as Superset evolves. 
+Superset ships with a set of roles that are handled by Superset itself.
+You can assume that these roles will stay up-to-date as Superset evolves.
+Even though it's possible for ``Admin`` users to do so, it is not recommended
+that you alter these roles in any way by removing
+or adding permissions to them as these roles will be re-synchronized to
+their original values as you run your next ``superset init`` command.
+
+Since it's not recommended to alter the roles described here, it's right
+to assume that your security strategy should be to compose user access based
+on these base roles and roles that you create. For instance you could
+create a role ``Financial Analyst`` that would be made of a set of permissions
+to a set of data sources (tables) and/or databases. Users would then be
+granted ``Gamma``, ``Financial Analyst``, and perhaps ``sql_lab``.
 
 Admin
 """""
@@ -33,6 +44,12 @@ mostly content consumers, though they can create slices and dashboards.
 Also note that when Gamma users look at the dashboards and slices list
 view, they will only see the objects that they have access to.
 
+sql_lab
+"""""""
+The ``sql_lab`` role grants access to SQL Lab. Note that while ``Admin``
+users have access to all databases by default, both ``Alpha`` and ``Gamma``
+users need to be given access on a per database basis.
+
 Managing Gamma per data source access
 -------------------------------------
 
diff --git a/superset/__init__.py b/superset/__init__.py
index d65da960c0c1..e55841410242 100644
--- a/superset/__init__.py
+++ b/superset/__init__.py
@@ -23,6 +23,8 @@ app = Flask(__name__)
 app.config.from_object(CONFIG_MODULE)
 
+conf = app.config
+
 if not app.debug:
     # In production mode, add log handler to sys.stderr.
     
app.logger.addHandler(logging.StreamHandler()) diff --git a/superset/assets/javascripts/SqlLab/actions.js b/superset/assets/javascripts/SqlLab/actions.js index 70a0b62a0d16..93c4afe75c4c 100644 --- a/superset/assets/javascripts/SqlLab/actions.js +++ b/superset/assets/javascripts/SqlLab/actions.js @@ -158,7 +158,7 @@ export function setNetworkStatus(networkOn) { export function addAlert(alert) { const o = Object.assign({}, alert); o.id = shortid.generate(); - return { type: ADD_ALERT, o }; + return { type: ADD_ALERT, alert: o }; } export function removeAlert(alert) { diff --git a/superset/assets/javascripts/SqlLab/components/DatabaseSelect.jsx b/superset/assets/javascripts/SqlLab/components/DatabaseSelect.jsx index a18741043473..49cbba0159b8 100644 --- a/superset/assets/javascripts/SqlLab/components/DatabaseSelect.jsx +++ b/superset/assets/javascripts/SqlLab/components/DatabaseSelect.jsx @@ -2,6 +2,13 @@ const $ = window.$ = require('jquery'); import React from 'react'; import Select from 'react-select'; +const propTypes = { + onChange: React.PropTypes.func, + actions: React.PropTypes.object, + databaseId: React.PropTypes.number, + valueRenderer: React.PropTypes.func, +}; + class DatabaseSelect extends React.PureComponent { constructor(props) { super(props); @@ -23,6 +30,12 @@ class DatabaseSelect extends React.PureComponent { const options = data.result.map((db) => ({ value: db.id, label: db.database_name })); this.setState({ databaseOptions: options, databaseLoading: false }); this.props.actions.setDatabases(data.result); + if (data.result.length === 0) { + this.props.actions.addAlert({ + bsStyle: 'danger', + msg: "It seems you don't have access to any database", + }); + } }); } render() { @@ -43,11 +56,6 @@ class DatabaseSelect extends React.PureComponent { } } -DatabaseSelect.propTypes = { - onChange: React.PropTypes.func, - actions: React.PropTypes.object, - databaseId: React.PropTypes.number, - valueRenderer: React.PropTypes.func, -}; +DatabaseSelect.propTypes 
= propTypes; export default DatabaseSelect; diff --git a/superset/bin/superset b/superset/bin/superset index 3d711fc826ed..169008efc5ad 100755 --- a/superset/bin/superset +++ b/superset/bin/superset @@ -4,156 +4,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import logging -import celery -from celery.bin import worker as celery_worker -from datetime import datetime -from subprocess import Popen - -from flask_migrate import MigrateCommand -from flask_script import Manager - -import superset -from superset import app, ascii_art, db, data, utils - -config = app.config - -manager = Manager(app) -manager.add_command('db', MigrateCommand) - - -@manager.option( - '-d', '--debug', action='store_true', - help="Start the web server in debug mode") -@manager.option( - '-a', '--address', default=config.get("SUPERSET_WEBSERVER_ADDRESS"), - help="Specify the address to which to bind the web server") -@manager.option( - '-p', '--port', default=config.get("SUPERSET_WEBSERVER_PORT"), - help="Specify the port on which to run the web server") -@manager.option( - '-w', '--workers', default=config.get("SUPERSET_WORKERS", 2), - help="Number of gunicorn web server workers to fire up") -@manager.option( - '-t', '--timeout', default=config.get("SUPERSET_WEBSERVER_TIMEOUT"), - help="Specify the timeout (seconds) for the gunicorn web server") -def runserver(debug, address, port, timeout, workers): - """Starts a Superset web server""" - debug = debug or config.get("DEBUG") - if debug: - app.run( - host='0.0.0.0', - port=int(port), - threaded=True, - debug=True) - else: - cmd = ( - "gunicorn " - "-w {workers} " - "--timeout {timeout} " - "-b {address}:{port} " - "--limit-request-line 0 " - "--limit-request-field_size 0 " - "superset:app").format(**locals()) - print("Starting server with command: " + cmd) - Popen(cmd, shell=True).wait() - -@manager.command -def init(): - """Inits the Superset application""" - 
utils.init(superset) - -@manager.option( - '-v', '--verbose', action='store_true', - help="Show extra information") -def version(verbose): - """Prints the current version number""" - s = ( - "\n{boat}\n\n" - "-----------------------\n" - "Superset {version}\n" - "-----------------------").format( - boat=ascii_art.boat, version=config.get('VERSION_STRING')) - print(s) - if verbose: - print("[DB] : " + "{}".format(db.engine)) - -@manager.option( - '-t', '--load-test-data', action='store_true', - help="Load additional test data") -def load_examples(load_test_data): - """Loads a set of Slices and Dashboards and a supporting dataset """ - print("Loading examples into {}".format(db)) - - data.load_css_templates() - - print("Loading energy related dataset") - data.load_energy() - - print("Loading [World Bank's Health Nutrition and Population Stats]") - data.load_world_bank_health_n_pop() - - print("Loading [Birth names]") - data.load_birth_names() - - print("Loading [Random time series data]") - data.load_random_time_series_data() - - print("Loading [Random long/lat data]") - data.load_long_lat_data() - - print("Loading [Multiformat time series]") - data.load_multiformat_time_series_data() - - print("Loading [Misc Charts] dashboard") - data.load_misc_dashboard() - - if load_test_data: - print("Loading [Unicode test data]") - data.load_unicode_test_data() - -@manager.option( - '-d', '--datasource', - help=( - "Specify which datasource name to load, if omitted, all " - "datasources will be refreshed")) -def refresh_druid(datasource): - """Refresh druid datasources""" - session = db.session() - from superset import models - for cluster in session.query(models.DruidCluster).all(): - try: - cluster.refresh_datasources(datasource_name=datasource) - except Exception as e: - print( - "Error while processing cluster '{}'\n{}".format( - cluster, str(e))) - logging.exception(e) - cluster.metadata_last_refreshed = datetime.now() - print( - "Refreshed metadata from cluster " - "[" + 
cluster.cluster_name + "]") - session.commit() - - -@manager.command -def worker(): - """Starts a Superset worker for async SQL query execution.""" - # celery -A tasks worker --loglevel=info - print("Starting SQL Celery worker.") - if config.get('CELERY_CONFIG'): - print("Celery broker url: ") - print(config.get('CELERY_CONFIG').BROKER_URL) - - application = celery.current_app._get_current_object() - c_worker = celery_worker.worker(app=application) - options = { - 'broker': config.get('CELERY_CONFIG').BROKER_URL, - 'loglevel': 'INFO', - 'traceback': True, - } - c_worker.run(**options) - +from superset.cli import manager if __name__ == "__main__": manager.run() diff --git a/superset/cli.py b/superset/cli.py new file mode 100755 index 000000000000..3c32e14a75c8 --- /dev/null +++ b/superset/cli.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import logging +import celery +from celery.bin import worker as celery_worker +from datetime import datetime +from subprocess import Popen + +from flask_migrate import MigrateCommand +from flask_script import Manager + +from superset import app, ascii_art, db, data, security + +config = app.config + +manager = Manager(app) +manager.add_command('db', MigrateCommand) + + +@manager.option( + '-d', '--debug', action='store_true', + help="Start the web server in debug mode") +@manager.option( + '-a', '--address', default=config.get("SUPERSET_WEBSERVER_ADDRESS"), + help="Specify the address to which to bind the web server") +@manager.option( + '-p', '--port', default=config.get("SUPERSET_WEBSERVER_PORT"), + help="Specify the port on which to run the web server") +@manager.option( + '-w', '--workers', default=config.get("SUPERSET_WORKERS", 2), + help="Number of gunicorn web server workers to fire up") +@manager.option( + '-t', '--timeout', default=config.get("SUPERSET_WEBSERVER_TIMEOUT"), 
+ help="Specify the timeout (seconds) for the gunicorn web server") +def runserver(debug, address, port, timeout, workers): + """Starts a Superset web server""" + debug = debug or config.get("DEBUG") + if debug: + app.run( + host='0.0.0.0', + port=int(port), + threaded=True, + debug=True) + else: + cmd = ( + "gunicorn " + "-w {workers} " + "--timeout {timeout} " + "-b {address}:{port} " + "--limit-request-line 0 " + "--limit-request-field_size 0 " + "superset:app").format(**locals()) + print("Starting server with command: " + cmd) + Popen(cmd, shell=True).wait() + + +@manager.command +def init(): + """Inits the Superset application""" + security.sync_role_definitions() + + +@manager.option( + '-v', '--verbose', action='store_true', + help="Show extra information") +def version(verbose): + """Prints the current version number""" + s = ( + "\n{boat}\n\n" + "-----------------------\n" + "Superset {version}\n" + "-----------------------").format( + boat=ascii_art.boat, version=config.get('VERSION_STRING')) + print(s) + if verbose: + print("[DB] : " + "{}".format(db.engine)) + + +@manager.option( + '-t', '--load-test-data', action='store_true', + help="Load additional test data") +def load_examples(load_test_data): + """Loads a set of Slices and Dashboards and a supporting dataset """ + print("Loading examples into {}".format(db)) + + data.load_css_templates() + + print("Loading energy related dataset") + data.load_energy() + + print("Loading [World Bank's Health Nutrition and Population Stats]") + data.load_world_bank_health_n_pop() + + print("Loading [Birth names]") + data.load_birth_names() + + print("Loading [Random time series data]") + data.load_random_time_series_data() + + print("Loading [Random long/lat data]") + data.load_long_lat_data() + + print("Loading [Multiformat time series]") + data.load_multiformat_time_series_data() + + print("Loading [Misc Charts] dashboard") + data.load_misc_dashboard() + + if load_test_data: + print("Loading [Unicode test data]") 
+ data.load_unicode_test_data() + + +@manager.option( + '-d', '--datasource', + help=( + "Specify which datasource name to load, if omitted, all " + "datasources will be refreshed")) +def refresh_druid(datasource): + """Refresh druid datasources""" + session = db.session() + from superset import models + for cluster in session.query(models.DruidCluster).all(): + try: + cluster.refresh_datasources(datasource_name=datasource) + except Exception as e: + print( + "Error while processing cluster '{}'\n{}".format( + cluster, str(e))) + logging.exception(e) + cluster.metadata_last_refreshed = datetime.now() + print( + "Refreshed metadata from cluster " + "[" + cluster.cluster_name + "]") + session.commit() + + +@manager.command +def worker(): + """Starts a Superset worker for async SQL query execution.""" + # celery -A tasks worker --loglevel=info + print("Starting SQL Celery worker.") + if config.get('CELERY_CONFIG'): + print("Celery broker url: ") + print(config.get('CELERY_CONFIG').BROKER_URL) + + application = celery.current_app._get_current_object() + c_worker = celery_worker.worker(app=application) + options = { + 'broker': config.get('CELERY_CONFIG').BROKER_URL, + 'loglevel': 'INFO', + 'traceback': True, + } + c_worker.run(**options) diff --git a/superset/config.py b/superset/config.py index dcd40c39b982..b16560c08a9e 100644 --- a/superset/config.py +++ b/superset/config.py @@ -8,7 +8,6 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -from superset import app import json import os diff --git a/superset/data/__init__.py b/superset/data/__init__.py index 0c3ef47d30dd..f88895b2f31f 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -14,8 +14,8 @@ import pandas as pd from sqlalchemy import String, DateTime, Date, Float, BigInteger -import superset from superset import app, db, models, utils +from superset.security import get_or_create_main_db # Shortcuts DB = models.Database @@ 
-67,7 +67,7 @@ def load_energy(): tbl = TBL(table_name=tbl_name) tbl.description = "Energy consumption" tbl.is_featured = True - tbl.database = utils.get_or_create_main_db(superset) + tbl.database = get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -194,7 +194,7 @@ def load_world_bank_health_n_pop(): tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md')) tbl.main_dttm_col = 'year' tbl.is_featured = True - tbl.database = utils.get_or_create_main_db(superset) + tbl.database = get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -586,7 +586,7 @@ def load_birth_names(): if not obj: obj = TBL(table_name='birth_names') obj.main_dttm_col = 'ds' - obj.database = utils.get_or_create_main_db(superset) + obj.database = get_or_create_main_db() obj.is_featured = True db.session.merge(obj) db.session.commit() @@ -834,7 +834,7 @@ def load_unicode_test_data(): if not obj: obj = TBL(table_name='unicode_test') obj.main_dttm_col = 'date' - obj.database = utils.get_or_create_main_db(superset) + obj.database = get_or_create_main_db() obj.is_featured = False db.session.merge(obj) db.session.commit() @@ -872,7 +872,11 @@ def load_unicode_test_data(): merge_slice(slc) print("Creating a dashboard") - dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first() + dash = ( + db.session.query(Dash) + .filter_by(dashboard_title="Unicode Test") + .first() + ) if not dash: dash = Dash() @@ -913,7 +917,7 @@ def load_random_time_series_data(): if not obj: obj = TBL(table_name='random_time_series') obj.main_dttm_col = 'ds' - obj.database = utils.get_or_create_main_db(superset) + obj.database = get_or_create_main_db() obj.is_featured = False db.session.merge(obj) db.session.commit() @@ -981,7 +985,7 @@ def load_long_lat_data(): if not obj: obj = TBL(table_name='long_lat') obj.main_dttm_col = 'date' - obj.database = utils.get_or_create_main_db(superset) + obj.database = 
get_or_create_main_db() obj.is_featured = False db.session.merge(obj) db.session.commit() @@ -1046,7 +1050,7 @@ def load_multiformat_time_series_data(): if not obj: obj = TBL(table_name='multiformat_time_series') obj.main_dttm_col = 'ds' - obj.database = utils.get_or_create_main_db(superset) + obj.database = get_or_create_main_db() obj.is_featured = False dttm_and_expr_dict = { 'ds': [None, None], diff --git a/superset/migrations/versions/e46f2d27a08e_materialize_perms.py b/superset/migrations/versions/e46f2d27a08e_materialize_perms.py new file mode 100644 index 000000000000..7611671fe16b --- /dev/null +++ b/superset/migrations/versions/e46f2d27a08e_materialize_perms.py @@ -0,0 +1,27 @@ +"""materialize perms + +Revision ID: e46f2d27a08e +Revises: c611f2b591b8 +Create Date: 2016-11-14 15:23:32.594898 + +""" + +# revision identifiers, used by Alembic. +revision = 'e46f2d27a08e' +down_revision = 'c611f2b591b8' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('datasources', sa.Column('perm', sa.String(length=1000), nullable=True)) + op.add_column('dbs', sa.Column('perm', sa.String(length=1000), nullable=True)) + op.add_column('tables', sa.Column('perm', sa.String(length=1000), nullable=True)) + + +def downgrade(): + op.drop_column('tables', 'perm') + op.drop_column('datasources', 'perm') + op.drop_column('dbs', 'perm') + diff --git a/superset/models.py b/superset/models.py index e7e5cda6ac90..8407bd1a4d0e 100644 --- a/superset/models.py +++ b/superset/models.py @@ -21,6 +21,7 @@ import sqlalchemy as sqla from sqlalchemy.engine.url import make_url from sqlalchemy.orm import subqueryload +from sqlalchemy.ext.hybrid import hybrid_property import sqlparse from dateutil.parser import parse @@ -69,6 +70,27 @@ FillterPattern = re.compile(r'''((?:[^,"']|"[^"]*"|'[^']*')+)''') +def set_perm(mapper, connection, target): # noqa + target.perm = target.get_perm() + + +def init_metrics_perm(metrics=None): + """Create permissions for restricted 
metrics + + :param metrics: a list of metrics to be processed, if not specified, + all metrics are processed + :type metrics: models.SqlMetric or models.DruidMetric + """ + if not metrics: + metrics = [] + for model in [SqlMetric, DruidMetric]: + metrics += list(db.session.query(model).all()) + + for metric in metrics: + if metric.is_restricted and metric.perm: + sm.add_permission_view_menu('metric_access', metric.perm) + + class JavascriptPostAggregator(Postaggregator): def __init__(self, name, field_names, function): self.post_aggregator = { @@ -198,7 +220,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin): params = Column(Text) description = Column(Text) cache_timeout = Column(Integer) - perm = Column(String(2000)) + perm = Column(String(1000)) owners = relationship("User", secondary=slice_user) export_fields = ('slice_name', 'datasource_type', 'datasource_name', @@ -365,14 +387,14 @@ def import_obj(cls, slc_to_import, import_time=None): return slc_to_import.id -def set_perm(mapper, connection, target): # noqa +def set_related_perm(mapper, connection, target): # noqa src_class = target.cls_model id_ = target.datasource_id ds = db.session.query(src_class).filter_by(id=int(id_)).first() target.perm = ds.perm -sqla.event.listen(Slice, 'before_insert', set_perm) -sqla.event.listen(Slice, 'before_update', set_perm) +sqla.event.listen(Slice, 'before_insert', set_related_perm) +sqla.event.listen(Slice, 'before_update', set_related_perm) dashboard_slices = Table( @@ -663,6 +685,7 @@ class Database(Model, AuditMixinNullable): "engine_params": {} } """)) + perm = Column(String(1000)) def __repr__(self): return self.database_name @@ -826,11 +849,13 @@ def sqlalchemy_uri_decrypted(self): def sql_url(self): return '/superset/sql/{}/'.format(self.id) - @property - def perm(self): + def get_perm(self): return ( "[{obj.database_name}].(id:{obj.id})").format(obj=self) +sqla.event.listen(Database, 'before_insert', set_perm) +sqla.event.listen(Database, 'before_update', 
set_perm) + class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): @@ -857,6 +882,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): schema = Column(String(255)) sql = Column(Text) params = Column(Text) + perm = Column(String(1000)) baselink = "tablemodelview" export_fields = ( @@ -882,8 +908,7 @@ def link(self): return Markup( '{table_name}'.format(**locals())) - @property - def perm(self): + def get_perm(self): return ( "[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self) @@ -1299,6 +1324,9 @@ def import_obj(cls, datasource_to_import, import_time=None): return datasource.id +sqla.event.listen(SqlaTable, 'before_insert', set_perm) +sqla.event.listen(SqlaTable, 'before_update', set_perm) + class SqlMetric(Model, AuditMixinNullable, ImportMixin): @@ -1574,6 +1602,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) offset = Column(Integer, default=0) cache_timeout = Column(Integer) + perm = Column(String(1000)) @property def database(self): @@ -1597,8 +1626,7 @@ def num_cols(self): def name(self): return self.datasource_name - @property - def perm(self): + def get_perm(self): return ( "[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format(obj=self) @@ -2178,6 +2206,9 @@ def get_having_filters(self, raw_filters): filters = cond return filters +sqla.event.listen(DruidDatasource, 'before_insert', set_perm) +sqla.event.listen(DruidDatasource, 'before_update', set_perm) + class Log(Model): @@ -2403,7 +2434,7 @@ def generate_metrics(self): session.add(metric) session.flush() - utils.init_metrics_perm(superset, new_metrics) + init_metrics_perm(new_metrics) class FavStar(Model): diff --git a/superset/security.py b/superset/security.py new file mode 100644 index 000000000000..52fc67ea3b45 --- /dev/null +++ b/superset/security.py @@ -0,0 +1,178 @@ +from __future__ import absolute_import +from __future__ import division +from 
__future__ import print_function +from __future__ import unicode_literals + +from itertools import product +import logging +from flask_appbuilder.security.sqla import models as ab_models + +from superset import conf, db, models, sm + + +READ_ONLY_MODELVIEWS = { + 'DatabaseAsync', + 'DatabaseView', + 'DruidClusterModelView', +} +ADMIN_ONLY_VIEW_MENUES = { + 'AccessRequestsModelView', + 'Manage', + 'SQL Lab', + 'Queries', + 'Refresh Druid Metadata', + 'ResetPasswordView', + 'RoleModelView', + 'Security', + 'UserDBModelView', +} | READ_ONLY_MODELVIEWS + +ADMIN_ONLY_PERMISSIONS = { + 'all_datasource_access', + 'all_database_access', + 'datasource_access', + 'database_access', + 'can_sql_json', + 'can_override_role_permissions', + 'can_sync_druid_source', + 'can_override_role_permissions', + 'can_approve', +} +READ_ONLY_PERMISSION = { + 'can_show', + 'can_list', +} + +ALPHA_ONLY_PERMISSIONS = set([ + 'can_add', + 'can_download', + 'can_delete', + 'can_edit', + 'can_save', + 'datasource_access', + 'database_access', + 'muldelete', +]) +READ_ONLY_PRODUCT = set( + product(READ_ONLY_PERMISSION, READ_ONLY_MODELVIEWS)) + + +def get_or_create_main_db(): + logging.info("Creating database reference") + dbobj = ( + db.session.query(models.Database) + .filter_by(database_name='main') + .first() + ) + if not dbobj: + dbobj = models.Database(database_name="main") + logging.info(conf.get("SQLALCHEMY_DATABASE_URI")) + dbobj.set_sqlalchemy_uri(conf.get("SQLALCHEMY_DATABASE_URI")) + dbobj.expose_in_sqllab = True + dbobj.allow_run_sync = True + db.session.add(dbobj) + db.session.commit() + return dbobj + + +def sync_role_definitions(): + """Inits the Superset application with security roles and such""" + logging.info("Syncing role definition") + + # Creating default roles + alpha = sm.add_role("Alpha") + admin = sm.add_role("Admin") + gamma = sm.add_role("Gamma") + public = sm.add_role("Public") + sql_lab = sm.add_role("sql_lab") + granter = sm.add_role("granter") + + 
get_or_create_main_db()
+
+    # Global perms
+    sm.add_permission_view_menu(
+        'all_datasource_access', 'all_datasource_access')
+    sm.add_permission_view_menu('all_database_access', 'all_database_access')
+
+    perms = db.session.query(ab_models.PermissionView).all()
+    perms = [p for p in perms if p.permission and p.view_menu]
+
+    logging.info("Syncing admin perms")
+    for p in perms:
+        sm.add_permission_role(admin, p)
+
+    logging.info("Syncing alpha perms")
+    for p in perms:
+        if (
+            (
+                p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
+                p.permission.name not in ADMIN_ONLY_PERMISSIONS
+            ) or
+            (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
+        ):
+            sm.add_permission_role(alpha, p)
+        else:
+            sm.del_permission_role(alpha, p)
+
+    logging.info("Syncing gamma perms and public if specified")
+    PUBLIC_ROLE_LIKE_GAMMA = conf.get('PUBLIC_ROLE_LIKE_GAMMA', False)
+    for p in perms:
+        if (
+            (
+                p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
+                p.permission.name not in ADMIN_ONLY_PERMISSIONS and
+                p.permission.name not in ALPHA_ONLY_PERMISSIONS
+            ) or
+            (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
+        ):
+            sm.add_permission_role(gamma, p)
+            if PUBLIC_ROLE_LIKE_GAMMA:
+                sm.add_permission_role(public, p)
+        else:
+            sm.del_permission_role(gamma, p)
+            sm.del_permission_role(public, p)
+
+    logging.info("Syncing sql_lab perms")
+    for p in perms:
+        if (
+            p.view_menu.name in {'SQL Lab'} or
+            p.permission.name in {
+                'can_sql_json', 'can_csv', 'can_search_queries'}
+        ):
+            sm.add_permission_role(sql_lab, p)
+        else:
+            sm.del_permission_role(sql_lab, p)
+
+    logging.info("Syncing granter perms")
+    for p in perms:
+        if (
+            p.permission.name in {
+                'can_override_role_permissions', 'can_approve'}
+        ):
+            sm.add_permission_role(granter, p)
+        else:
+            sm.del_permission_role(granter, p)
+
+    logging.info("Making sure all data source perms have been created")
+    session = db.session()
+    datasources = [
+        o for o in session.query(models.SqlaTable).all()]
+    datasources += [
+        o for 
o in session.query(models.DruidDatasource).all()] + for datasource in datasources: + perm = datasource.get_perm() + sm.add_permission_view_menu('datasource_access', perm) + if perm != datasource.perm: + datasource.perm = perm + + logging.info("Making sure all database perms have been created") + databases = [o for o in session.query(models.Database).all()] + for database in databases: + perm = database.get_perm() + if perm != database.perm: + database.perm = perm + sm.add_permission_view_menu('database_access', perm) + session.commit() + + logging.info("Making sure all metrics perms exist") + models.init_metrics_perm() diff --git a/superset/utils.py b/superset/utils.py index 2747b119f786..9e14f0916e12 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -20,7 +20,6 @@ import sqlalchemy as sa from dateutil.parser import parse from flask import flash, Markup -from flask_appbuilder.security.sqla import models as ab_models import markdown as md from sqlalchemy.types import TypeDecorator, TEXT from pydruid.utils.having import Having @@ -109,23 +108,6 @@ def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) -def get_or_create_main_db(superset): - db = superset.db - config = superset.app.config - DB = superset.models.Database - logging.info("Creating database reference") - dbobj = db.session.query(DB).filter_by(database_name='main').first() - if not dbobj: - dbobj = DB(database_name="main") - logging.info(config.get("SQLALCHEMY_DATABASE_URI")) - dbobj.set_sqlalchemy_uri(config.get("SQLALCHEMY_DATABASE_URI")) - dbobj.expose_in_sqllab = True - dbobj.allow_run_sync = True - db.session.add(dbobj) - db.session.commit() - return dbobj - - class DimSelector(Having): def __init__(self, **args): # Just a hack to prevent any exceptions @@ -185,12 +167,6 @@ def dttm_from_timtuple(d): d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) -def merge_perm(sm, permission_name, view_menu_name): - pv = sm.find_permission_view_menu(permission_name, 
view_menu_name) - if not pv: - sm.add_permission_view_menu(permission_name, view_menu_name) - - def parse_human_timedelta(s): """ Returns ``datetime.datetime`` from natural language time deltas @@ -224,113 +200,6 @@ def process_result_value(self, value, dialect): return value -def init(superset): - """Inits the Superset application with security roles and such""" - ADMIN_ONLY_VIEW_MENUES = set([ - 'ResetPasswordView', - 'RoleModelView', - 'Security', - 'UserDBModelView', - 'SQL Lab', - 'AccessRequestsModelView', - 'Manage', - ]) - - ADMIN_ONLY_PERMISSIONS = set([ - 'can_sync_druid_source', - 'can_override_role_permissions', - 'can_approve', - ]) - - ALPHA_ONLY_PERMISSIONS = set([ - 'all_datasource_access', - 'can_add', - 'can_download', - 'can_delete', - 'can_edit', - 'can_save', - 'datasource_access', - 'database_access', - 'muldelete', - ]) - - db = superset.db - models = superset.models - config = superset.app.config - sm = superset.appbuilder.sm - alpha = sm.add_role("Alpha") - admin = sm.add_role("Admin") - get_or_create_main_db(superset) - - merge_perm(sm, 'all_datasource_access', 'all_datasource_access') - - perms = db.session.query(ab_models.PermissionView).all() - # set alpha and admin permissions - for perm in perms: - if ( - perm.permission and - perm.permission.name in ('datasource_access', 'database_access')): - continue - if ( - perm.view_menu and - perm.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and - perm.permission and - perm.permission.name not in ADMIN_ONLY_PERMISSIONS): - - sm.add_permission_role(alpha, perm) - sm.add_permission_role(admin, perm) - - gamma = sm.add_role("Gamma") - public_role = sm.find_role("Public") - public_role_like_gamma = \ - public_role and config.get('PUBLIC_ROLE_LIKE_GAMMA', False) - - # set gamma permissions - for perm in perms: - if ( - perm.view_menu and - perm.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and - perm.permission and - perm.permission.name not in ADMIN_ONLY_PERMISSIONS and - perm.permission.name not in 
ALPHA_ONLY_PERMISSIONS): - sm.add_permission_role(gamma, perm) - if public_role_like_gamma: - sm.add_permission_role(public_role, perm) - session = db.session() - table_perms = [ - table.perm for table in session.query(models.SqlaTable).all()] - table_perms += [ - table.perm for table in session.query(models.DruidDatasource).all()] - for table_perm in table_perms: - merge_perm(sm, 'datasource_access', table_perm) - - db_perms = [db.perm for db in session.query(models.Database).all()] - for db_perm in db_perms: - merge_perm(sm, 'database_access', db_perm) - init_metrics_perm(superset) - - -def init_metrics_perm(superset, metrics=None): - """Create permissions for restricted metrics - - :param metrics: a list of metrics to be processed, if not specified, - all metrics are processed - :type metrics: models.SqlMetric or models.DruidMetric - """ - db = superset.db - models = superset.models - sm = superset.appbuilder.sm - - if not metrics: - metrics = [] - for model in [models.SqlMetric, models.DruidMetric]: - metrics += list(db.session.query(model).all()) - - for metric in metrics: - if metric.is_restricted and metric.perm: - merge_perm(sm, 'metric_access', metric.perm) - - def datetime_f(dttm): """Formats datetime to take less room when it is recent""" if dttm: diff --git a/superset/views.py b/superset/views.py index 42ad66f9e424..05aaa76b818f 100755 --- a/superset/views.py +++ b/superset/views.py @@ -3,6 +3,7 @@ from __future__ import print_function from __future__ import unicode_literals +from datetime import datetime, timedelta import json import logging import pickle @@ -11,7 +12,6 @@ import time import traceback import zlib -from datetime import datetime, timedelta import functools import sqlalchemy as sqla @@ -28,14 +28,13 @@ from flask_appbuilder.models.sqla.filters import BaseFilter from sqlalchemy import create_engine -from werkzeug.datastructures import ImmutableMultiDict from werkzeug.routing import BaseConverter from wtforms.validators import 
ValidationError import superset from superset import ( appbuilder, cache, db, models, viz, utils, app, - sm, ascii_art, sql_lab, results_backend + sm, ascii_art, sql_lab, results_backend, security, ) from superset.source_registry import SourceRegistry from superset.models import DatasourceAccessRequest as DAR @@ -55,12 +54,17 @@ def all_datasource_access(self): "all_datasource_access", "all_datasource_access") def database_access(self, database): - return (self.all_datasource_access() or - self.can_access("database_access", database.perm)) + return ( + self.can_access("all_database_access", "all_database_access") or + self.can_access("database_access", database.perm) + ) def datasource_access(self, datasource): - return (self.database_access(datasource.database) or - self.can_access("datasource_access", datasource.perm)) + return ( + self.database_access(datasource.database) or + self.can_access("all_database_access", "all_database_access") or + self.can_access("datasource_access", datasource.perm) + ) class ListWidgetWithCheckboxes(ListWidget): @@ -181,47 +185,94 @@ def get_user_roles(): class SupersetFilter(BaseFilter): - def get_perms(self): - perms = [] + + """Add utility function to make BaseFilter easy and fast + + These utility function exist in the SecurityManager, but would do + a database round trip at every check. 
Here we cache the role objects + to be able to make multiple checks but query the db only once + """ + + def get_user_roles(self): + attr = '__get_user_roles' + if not hasattr(self, attr): + setattr(self, attr, get_user_roles()) + return getattr(self, attr) + + def get_all_permissions(self): + """Returns a set of tuples with the perm name and view menu name""" + perms = set() for role in get_user_roles(): for perm_view in role.permissions: - if perm_view.permission.name == 'datasource_access': - perms.append(perm_view.view_menu.name) + t = (perm_view.permission.name, perm_view.view_menu.name) + perms.add(t) return perms + def has_role(self, role_name_or_list): + """Whether the user has this role name""" + if not isinstance(role_name_or_list, list): + role_name_or_list = [role_name_or_list] + return any( + [r.name in role_name_or_list for r in self.get_user_roles()]) + + def has_perm(self, permission_name, view_menu_name): + """Whether the user has this perm""" + return (permission_name, view_menu_name) in self.get_all_permissions() + + def get_view_menus(self, permission_name): + """Returns the details of view_menus for a perm name""" + vm = set() + for perm_name, vm_name in self.get_all_permissions(): + if perm_name == permission_name: + vm.add(vm_name) + return vm + + def has_all_datasource_access(self): + return ( + self.has_role(['Admin', 'Alpha']) or + self.has_perm('all_datasource_access', 'all_datasource_access')) + -class TableSlice(SupersetFilter): +class DatabaseFilter(SupersetFilter): def apply(self, query, func): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): + if ( + self.has_role('Admin') or + self.has_perm('all_database_access', 'all_database_access')): return query - perms = self.get_perms() - tables = [] - for perm in perms: - match = re.search(r'\(id:(\d+)\)', perm) - tables.append(match.group(1)) - qry = query.filter(self.model.id.in_(tables)) - return qry + perms = self.get_view_menus('database_access') + return 
query.filter(self.model.perm.in_(perms)) -class FilterSlice(SupersetFilter): +class DatasourceFilter(SupersetFilter): def apply(self, query, func): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): + if self.has_all_datasource_access(): return query - qry = query.filter(self.model.perm.in_(self.get_perms())) - return qry + perms = self.get_view_menus('datasource_access') + return query.filter(self.model.perm.in_(perms)) -class FilterDashboard(SupersetFilter): +class SliceFilter(SupersetFilter): + def apply(self, query, func): # noqa + if self.has_all_datasource_access(): + return query + perms = self.get_view_menus('datasource_access') + return query.filter(self.model.perm.in_(perms)) + + +class DashboardFilter(SupersetFilter): + """List dashboards for which users have access to at least one slice""" + def apply(self, query, func): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): + if self.has_all_datasource_access(): return query Slice = models.Slice # noqa Dash = models.Dashboard # noqa + datasource_perms = self.get_view_menus('datasource_access') slice_ids_qry = ( db.session .query(Slice.id) - .filter(Slice.perm.in_(self.get_perms())) + .filter(Slice.perm.in_(datasource_perms)) ) query = query.filter( Dash.id.in_( @@ -233,37 +284,6 @@ def apply(self, query, func): # noqa ) return query - -class FilterDashboardSlices(SupersetFilter): - def apply(self, query, value): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): - return query - qry = query.filter(self.model.perm.in_(self.get_perms())) - return qry - - -class FilterDashboardOwners(SupersetFilter): - def apply(self, query, value): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): - return query - qry = query.filter_by(id=g.user.id) - return qry - - -class FilterDruidDatasource(SupersetFilter): - def apply(self, query, func): # noqa - if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): - return 
query - perms = self.get_perms() - druid_datasources = [] - for perm in perms: - match = re.search(r'\(id:(\d+)\)', perm) - if match: - druid_datasources.append(match.group(1)) - qry = query.filter(self.model.id.in_(druid_datasources)) - return qry - - def validate_json(form, field): # noqa try: json.loads(field.data) @@ -494,10 +514,11 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa 'extra', 'database_name', 'sqlalchemy_uri', + 'perm', 'created_by', 'created_on', 'changed_by', - 'changed_on' + 'changed_on', ] add_template = "superset/models/database/add.html" edit_template = "superset/models/database/edit.html" @@ -551,7 +572,7 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa def pre_add(self, db): db.set_sqlalchemy_uri(db.sqlalchemy_uri) - utils.merge_perm(sm, 'database_access', db.perm) + security.merge_perm(sm, 'database_access', db.perm) def pre_update(self, db): self.pre_add(db) @@ -578,6 +599,7 @@ def pre_update(self, db): class DatabaseAsync(DatabaseView): + base_filters = [['id', DatabaseFilter, lambda: []]] list_columns = [ 'id', 'database_name', 'expose_in_sqllab', 'allow_ctas', 'force_ctas_schema', @@ -605,6 +627,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa 'table_name', 'sql', 'is_featured', 'database', 'schema', 'description', 'owner', 'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout'] + show_columns = edit_columns + ['perm'] related_views = [TableColumnInlineView, SqlMetricInlineView] base_order = ('changed_on', 'desc') description_columns = { @@ -622,7 +645,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa "run a query against this string as a subquery." 
), } - base_filters = [['id', TableSlice, lambda: []]] + base_filters = [['id', DatasourceFilter, lambda: []]] label_columns = { 'link': _("Table"), 'changed_by_': _("Changed By"), @@ -659,7 +682,7 @@ def pre_add(self, table): def post_add(self, table): table.fetch_metadata() - utils.merge_perm(sm, 'datasource_access', table.perm) + security.merge_perm(sm, 'datasource_access', table.perm) flash(_( "The table was created. As part of this two phase configuration " "process, you should now click the edit button by " @@ -725,7 +748,7 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa } def pre_add(self, cluster): - utils.merge_perm(sm, 'database_access', cluster.perm) + security.merge_perm(sm, 'database_access', cluster.perm) def pre_update(self, cluster): self.pre_add(cluster) @@ -769,7 +792,7 @@ class SliceModelView(SupersetModelView, DeleteMixin): # noqa "Duration (in seconds) of the caching timeout for this slice." ), } - base_filters = [['id', FilterSlice, lambda: []]] + base_filters = [['id', SliceFilter, lambda: []]] label_columns = { 'cache_timeout': _("Cache Timeout"), 'creator': _("Creator"), @@ -865,15 +888,11 @@ class DashboardModelView(SupersetModelView, DeleteMixin): # noqa "want to alter specific parameters."), 'owners': _("Owners is a list of users who can alter the dashboard."), } - base_filters = [['slice', FilterDashboard, lambda: []]] + base_filters = [['slice', DashboardFilter, lambda: []]] add_form_query_rel_fields = { - 'slices': [['slices', FilterDashboardSlices, None]], - 'owners': [['owners', FilterDashboardOwners, None]], - } - edit_form_query_rel_fields = { - 'slices': [['slices', FilterDashboardSlices, None]], - 'owners': [['owners', FilterDashboardOwners, None]], + 'slices': [['slices', SliceFilter, None]], } + edit_form_query_rel_fields = add_form_query_rel_fields label_columns = { 'dashboard_link': _("Dashboard"), 'dashboard_title': _("Title"), @@ -964,6 +983,19 @@ class LogModelView(SupersetModelView): 
icon="fa-list-ol") +class QueryView(SupersetModelView): + datamodel = SQLAInterface(models.Query) + list_columns = ['user', 'database', 'status', 'start_time', 'end_time'] + +appbuilder.add_view( + QueryView, + "Queries", + label=__("Queries"), + category="Manage", + category_label=__("Manage"), + icon="fa-search") + + class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.DruidDatasource) list_widget = ListWidgetWithCheckboxes @@ -977,6 +1009,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa 'is_featured', 'is_hidden', 'default_endpoint', 'offset', 'cache_timeout'] add_columns = edit_columns + show_columns = add_columns + ['perm'] page_size = 500 base_order = ('datasource_name', 'asc') description_columns = { @@ -985,7 +1018,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa "Supports markdown"), } - base_filters = [['id', FilterDruidDatasource, lambda: []]] + base_filters = [['id', DatasourceFilter, lambda: []]] label_columns = { 'datasource_link': _("Data Source"), 'cluster': _("Cluster"), @@ -1013,7 +1046,7 @@ def pre_add(self, datasource): def post_add(self, datasource): datasource.generate_metrics() - utils.merge_perm(sm, 'datasource_access', datasource.perm) + security.merge_perm(sm, 'datasource_access', datasource.perm) def post_update(self, datasource): self.post_add(datasource) @@ -1073,7 +1106,7 @@ def msg(self): class Superset(BaseSupersetView): """The base views for Superset!""" - @has_access + @has_access_api @expose("/override_role_permissions/", methods=['POST']) def override_role_permissions(self): """Updates the role with the give datasource permissions. 
@@ -1863,29 +1896,6 @@ def sqllab_viz(self): url = '/superset/explore/table/{table.id}/?{params}'.format(**locals()) return redirect(url) - @has_access - @expose("/sql//") - @log_this - def sql(self, database_id): - if not self.all_datasource_access(): - flash(ALL_DATASOURCE_ACCESS_ERR, "danger") - return redirect("/tablemodelview/list/") - - mydb = db.session.query( - models.Database).filter_by(id=database_id).first() - if not self.database_access(mydb.perm): - flash(get_database_access_error_msg(mydb.database_name), "danger") - return redirect("/tablemodelview/list/") - engine = mydb.get_sqla_engine() - tables = engine.table_names() - - table_name = request.args.get('table_name') - return self.render_template( - "superset/sql.html", - tables=tables, - table_name=table_name, - db=mydb) - @has_access @expose("/table////") @log_this @@ -2284,7 +2294,6 @@ def show_traceback(self): title=ascii_art.stacktrace, art=ascii_art.error), 500 - @has_access @expose("/welcome") def welcome(self): """Personalized welcome page""" @@ -2301,6 +2310,7 @@ def sqlanvil(self): if config['DRUID_IS_ACTIVE']: appbuilder.add_link( "Refresh Druid Metadata", + label=__("Refresh Druid Metadata"), href='/superset/refresh_datasources/', category='Sources', category_label=__("Sources"), diff --git a/tests/base_tests.py b/tests/base_tests.py index 74bd3d2f8873..954d8d69b012 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -4,20 +4,19 @@ from __future__ import print_function from __future__ import unicode_literals -import imp +import logging import json import os import unittest from flask_appbuilder.security.sqla import models as ab_models -import superset -from superset import app, db, models, utils, appbuilder, sm +from superset import app, cli, db, models, appbuilder, sm +from superset.security import sync_role_definitions os.environ['SUPERSET_CONFIG'] = 'tests.superset_test_config' BASE_DIR = app.config.get("BASE_DIR") -cli = imp.load_source('cli', BASE_DIR + "/bin/superset") 
class SupersetTestCase(unittest.TestCase): @@ -30,13 +29,22 @@ def __init__(self, *args, **kwargs): not os.environ.get('SOLO_TEST') and not os.environ.get('examples_loaded') ): + logging.info("Loading examples") cli.load_examples(load_test_data=True) - utils.init(superset) + logging.info("Done loading examples") + sync_role_definitions() os.environ['examples_loaded'] = '1' + else: + sync_role_definitions() super(SupersetTestCase, self).__init__(*args, **kwargs) self.client = app.test_client() self.maxDiff = None - utils.init(superset) + + gamma_sqllab = sm.add_role("gamma_sqllab") + for perm in sm.find_role('Gamma').permissions: + sm.add_permission_role(gamma_sqllab, perm) + for perm in sm.find_role('sql_lab').permissions: + sm.add_permission_role(gamma_sqllab, perm) admin = appbuilder.sm.find_user('admin') if not admin: @@ -52,6 +60,13 @@ def __init__(self, *args, **kwargs): appbuilder.sm.find_role('Gamma'), password='general') + gamma_sqllab = appbuilder.sm.find_user('gamma_sqllab') + if not gamma_sqllab: + gamma_sqllab = appbuilder.sm.add_user( + 'gamma_sqllab', 'gamma_sqllab', 'user', 'gamma_sqllab@fab.org', + appbuilder.sm.find_role('gamma_sqllab'), + password='general') + alpha = appbuilder.sm.find_user('alpha') if not alpha: appbuilder.sm.add_user( @@ -80,7 +95,6 @@ def __init__(self, *args, **kwargs): session.add(druid_datasource2) session.commit() - utils.init(superset) def get_or_create(self, cls, criteria, session): obj = session.query(cls).filter_by(**criteria).first() @@ -89,11 +103,10 @@ def get_or_create(self, cls, criteria, session): return obj def login(self, username='admin', password='general'): - resp = self.client.post( + resp = self.get_resp( '/login/', - data=dict(username=username, password=password), - follow_redirects=True) - assert 'Welcome' in resp.data.decode('utf-8') + data=dict(username=username, password=password)) + self.assertIn('Welcome', resp) def get_query_by_sql(self, sql): session = db.create_scoped_session() @@ -128,14 
+141,19 @@ def get_druid_ds_by_name(self, name): return db.session.query(models.DruidDatasource).filter_by( datasource_name=name).first() - def get_resp(self, url): + def get_resp(self, url, data=None, follow_redirects=True): """Shortcut to get the parsed results while following redirects""" - resp = self.client.get(url, follow_redirects=True) - return resp.data.decode('utf-8') - - def get_json_resp(self, url): + if data: + resp = self.client.post( + url, data=data, follow_redirects=follow_redirects) + return resp.data.decode('utf-8') + else: + resp = self.client.get(url, follow_redirects=follow_redirects) + return resp.data.decode('utf-8') + + def get_json_resp(self, url, data=None): """Shortcut to get the parsed results while following redirects""" - resp = self.get_resp(url) + resp = self.get_resp(url, data=data) return json.loads(resp) def get_main_database(self, session): @@ -160,29 +178,30 @@ def get_access_requests(self, username, ds_type, ds_id): def logout(self): self.client.get('/logout/', follow_redirects=True) - def setup_public_access_for_dashboard(self, table_name): + def grant_public_access_to_table(self, table): public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() for perm in perms: if (perm.permission.name == 'datasource_access' and - perm.view_menu and table_name in perm.view_menu.name): + perm.view_menu and table.perm in perm.view_menu.name): appbuilder.sm.add_permission_role(public_role, perm) - def revoke_public_access(self, table_name): + def revoke_public_access_to_table(self, table): public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() for perm in perms: if (perm.permission.name == 'datasource_access' and - perm.view_menu and table_name in perm.view_menu.name): + perm.view_menu and table.perm in perm.view_menu.name): appbuilder.sm.del_permission_role(public_role, perm) - def run_sql(self, sql, user_name, client_id): - 
self.login(username=(user_name if user_name else 'admin')) + def run_sql(self, sql, client_id, user_name=None): + if user_name: + self.logout() + self.login(username=(user_name if user_name else 'admin')) dbid = self.get_main_database(db.session).id - resp = self.client.post( + resp = self.get_json_resp( '/superset/sql_json/', data=dict(database_id=dbid, sql=sql, select_as_create_as=False, client_id=client_id), ) - self.logout() - return json.loads(resp.data.decode('utf-8')) + return resp diff --git a/tests/celery_tests.py b/tests/celery_tests.py index c282c073c278..af58afef6645 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -4,7 +4,6 @@ from __future__ import print_function from __future__ import unicode_literals -import imp import json import os import subprocess @@ -13,15 +12,14 @@ import pandas as pd -import superset -from superset import app, appbuilder, db, models, sql_lab, utils, dataframe +from superset import app, appbuilder, cli, db, models, sql_lab, dataframe +from superset.security import sync_role_definitions from .base_tests import SupersetTestCase QueryStatus = models.QueryStatus BASE_DIR = app.config.get('BASE_DIR') -cli = imp.load_source('cli', BASE_DIR + '/bin/superset') class CeleryConfig(object): @@ -99,7 +97,7 @@ def setUpClass(cls): except OSError as e: app.logger.warn(str(e)) - utils.init(superset) + sync_role_definitions() worker_command = BASE_DIR + '/bin/superset worker' subprocess.Popen( @@ -179,6 +177,7 @@ def test_add_limit_to_the_query(self): def test_run_sync_query(self): main_db = self.get_main_database(db.session) eng = main_db.get_sqla_engine() + perm_name = 'can_sql_json' db_id = main_db.id # Case 1. @@ -189,7 +188,8 @@ def test_run_sync_query(self): # Case 2. # Table and DB exists, CTA call to the backend. 
- sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'" + sql_where = ( + "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) result2 = self.run_sql( db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true') self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) @@ -200,7 +200,7 @@ def test_run_sync_query(self): # Check the data in the tmp table. df2 = pd.read_sql_query(sql=query2.select_sql, con=eng) data2 = df2.to_dict(orient='records') - self.assertEqual([{'name': 'can_sql'}], data2) + self.assertEqual([{'name': perm_name}], data2) # Case 3. # Table and DB exists, CTA call to the backend, no data. diff --git a/tests/core_tests.py b/tests/core_tests.py index 5b3e23348f01..b848ba853b03 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -24,8 +24,6 @@ class CoreTests(SupersetTestCase): requires_examples = True def __init__(self, *args, **kwargs): - # Load examples first, so that we setup proper permission-view - # relations for all example data sources. 
super(CoreTests, self).__init__(*args, **kwargs) @classmethod @@ -118,7 +116,9 @@ def test_update_explore(self): def test_save_slice(self): self.login(username='admin') - slice_id = self.get_slice("Energy Sankey", db.session).id + slice_name = "Energy Sankey" + slice_id = self.get_slice(slice_name, db.session).id + db.session.commit() copy_name = "Test Sankey Save" tbl_id = self.table_ids.get('energy_usage') url = ( @@ -128,9 +128,15 @@ def test_save_slice(self): "collapsed_fieldsets=&action={}&datasource_name=energy_usage&" "datasource_id=1&datasource_type=table&previous_viz_type=sankey") - db.session.commit() + # Changing name resp = self.get_resp(url.format(tbl_id, slice_id, copy_name, 'save')) assert copy_name in resp + + # Setting the name back to its original name + resp = self.get_resp(url.format(tbl_id, slice_id, slice_name, 'save')) + assert slice_name in resp + + # Doing a basic overwrite assert 'Energy' in self.get_resp( url.format(tbl_id, slice_id, copy_name, 'overwrite')) @@ -281,15 +287,15 @@ def test_gamma(self): assert "List Dashboard" in self.get_resp('/dashboardmodelview/list/') def test_csv_endpoint(self): + self.login('admin') sql = """ SELECT first_name, last_name FROM ab_user WHERE first_name='admin' """ client_id = "{}".format(random.getrandbits(64))[:10] - self.run_sql(sql, 'admin', client_id) + self.run_sql(sql, client_id) - self.login('admin') resp = self.get_resp('/superset/csv/{}'.format(client_id)) data = csv.reader(io.StringIO(resp)) expected_data = csv.reader( @@ -299,36 +305,48 @@ def test_csv_endpoint(self): self.logout() def test_public_user_dashboard_access(self): + table = ( + db.session + .query(models.SqlaTable) + .filter_by(table_name='birth_names') + .one() + ) # Try access before adding appropriate permissions. 
- self.revoke_public_access('birth_names') + self.revoke_public_access_to_table(table) self.logout() resp = self.get_resp('/slicemodelview/list/') - assert 'birth_names' not in resp + self.assertNotIn('birth_names', resp) resp = self.get_resp('/dashboardmodelview/list/') - assert '/superset/dashboard/births/' not in resp + self.assertNotIn('/superset/dashboard/births/', resp) - self.setup_public_access_for_dashboard('birth_names') + self.grant_public_access_to_table(table) # Try access after adding appropriate permissions. - assert 'birth_names' in self.get_resp('/slicemodelview/list/') + self.assertIn('birth_names', self.get_resp('/slicemodelview/list/')) resp = self.get_resp('/dashboardmodelview/list/') - assert "/superset/dashboard/births/" in resp + self.assertIn("/superset/dashboard/births/", resp) - assert 'Births' in self.get_resp('/superset/dashboard/births/') + self.assertIn('Births', self.get_resp('/superset/dashboard/births/')) # Confirm that public doesn't have access to other datasets. 
resp = self.get_resp('/slicemodelview/list/') - assert 'wb_health_population' not in resp + self.assertNotIn('wb_health_population', resp) resp = self.get_resp('/dashboardmodelview/list/') - assert "/superset/dashboard/world_health/" not in resp + self.assertNotIn("/superset/dashboard/world_health/", resp) def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): self.logout() - self.setup_public_access_for_dashboard('birth_names') + table = ( + db.session + .query(models.SqlaTable) + .filter_by(table_name='birth_names') + .one() + ) + self.grant_public_access_to_table(table) dash = db.session.query(models.Dashboard).filter_by(dashboard_title="Births").first() dash.owners = [appbuilder.sm.find_user('admin')] @@ -382,8 +400,9 @@ def test_process_template(self): self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) def test_templated_sql_json(self): + self.login('admin') sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}' as test" - data = self.run_sql(sql, "admin", "fdaklj3ws") + data = self.run_sql(sql, "fdaklj3ws") self.assertEqual(data['data'][0]['test'], "2017-01-01T00:00:00") def test_table_metadata(self): diff --git a/tests/druid_tests.py b/tests/druid_tests.py index ea1d7b116251..96a63fb1b1fd 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -241,8 +241,8 @@ def test_filter_druid_datasource(self): no_gamma_ds.cluster = cluster db.session.merge(no_gamma_ds) - utils.merge_perm(sm, 'datasource_access', gamma_ds.perm) - utils.merge_perm(sm, 'datasource_access', no_gamma_ds.perm) + sm.add_permission_view_menu('datasource_access', gamma_ds.perm) + sm.add_permission_view_menu('datasource_access', no_gamma_ds.perm) db.session.commit() diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index a5c5dbc8b28d..cd8a93d806bc 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -4,9 +4,8 @@ from __future__ import print_function from __future__ import unicode_literals -import csv +from datetime import datetime, 
timedelta import json -import io import unittest from flask_appbuilder.security.sqla import models as ab_models @@ -20,25 +19,41 @@ class SqlLabTests(SupersetTestCase): def __init__(self, *args, **kwargs): super(SqlLabTests, self).__init__(*args, **kwargs) - def setUp(self): + def run_some_queries(self): + self.logout() db.session.query(models.Query).delete() - self.run_sql("SELECT * FROM ab_user", 'admin', client_id='client_id_1') - self.run_sql("SELECT * FROM NO_TABLE", 'admin', client_id='client_id_3') - self.run_sql("SELECT * FROM ab_permission", 'gamma', client_id='client_id_2') + db.session.commit() + self.run_sql( + "SELECT * FROM ab_user", + client_id='client_id_1', + user_name='admin') + self.run_sql( + "SELECT * FROM NO_TABLE", + client_id='client_id_3', + user_name='admin') + self.run_sql( + "SELECT * FROM ab_permission", + client_id='client_id_2', + user_name='gamma_sqllab') + self.logout() def tearDown(self): db.session.query(models.Query).delete() + db.session.commit() + self.logout() def test_sql_json(self): - data = self.run_sql('SELECT * FROM ab_user', 'admin', "1") + self.login('admin') + + data = self.run_sql('SELECT * FROM ab_user', "1") self.assertLess(0, len(data['data'])) - data = self.run_sql('SELECT * FROM unexistant_table', 'admin', "2") + data = self.run_sql('SELECT * FROM unexistant_table', "2") self.assertLess(0, len(data['error'])) def test_sql_json_has_access(self): main_db = self.get_main_database(db.session) - utils.merge_perm(sm, 'database_access', main_db.perm) + sm.add_permission_view_menu('database_access', main_db.perm) db.session.commit() main_db_permission_view = ( db.session.query(ab_models.PermissionView) @@ -48,119 +63,133 @@ def test_sql_json_has_access(self): ) astronaut = sm.add_role("Astronaut") sm.add_permission_role(astronaut, main_db_permission_view) - # Astronaut role is Gamma + main db permissions - for gamma_perm in sm.find_role('Gamma').permissions: - sm.add_permission_role(astronaut, gamma_perm) + # Astronaut 
role is Gamma + sqllab + main db permissions + for perm in sm.find_role('Gamma').permissions: + sm.add_permission_role(astronaut, perm) + for perm in sm.find_role('sql_lab').permissions: + sm.add_permission_role(astronaut, perm) gagarin = appbuilder.sm.find_user('gagarin') if not gagarin: appbuilder.sm.add_user( 'gagarin', 'Iurii', 'Gagarin', 'gagarin@cosmos.ussr', - appbuilder.sm.find_role('Astronaut'), + astronaut, password='general') - data = self.run_sql('SELECT * FROM ab_user', 'gagarin', "3") + data = self.run_sql('SELECT * FROM ab_user', "3", user_name='gagarin') db.session.query(models.Query).delete() db.session.commit() self.assertLess(0, len(data['data'])) def test_queries_endpoint(self): - resp = self.client.get('/superset/queries/{}'.format(0)) + self.run_some_queries() + + # Not logged in, should error out + resp = self.client.get('/superset/queries/0') self.assertEquals(403, resp.status_code) + # Admin sees queries self.login('admin') - data = self.get_json_resp('/superset/queries/{}'.format(0)) + data = self.get_json_resp('/superset/queries/0') self.assertEquals(2, len(data)) - self.logout() - self.run_sql("SELECT * FROM ab_user1", 'admin', client_id='client_id_4') - self.run_sql("SELECT * FROM ab_user2", 'admin', client_id='client_id_5') + # Run 2 more queries + self.run_sql("SELECT * FROM ab_user1", client_id='client_id_4') + self.run_sql("SELECT * FROM ab_user2", client_id='client_id_5') self.login('admin') - data = self.get_json_resp('/superset/queries/{}'.format(0)) + data = self.get_json_resp('/superset/queries/0') self.assertEquals(4, len(data)) + now = datetime.now() + timedelta(days=1) query = db.session.query(models.Query).filter_by( sql='SELECT * FROM ab_user1').first() - query.changed_on = utils.EPOCH + query.changed_on = now db.session.commit() - data = self.get_json_resp('/superset/queries/{}'.format(123456000)) - self.assertEquals(3, len(data)) + data = self.get_json_resp( + '/superset/queries/{}'.format( + 
int(utils.datetime_to_epoch(now))-1000)) + self.assertEquals(1, len(data)) self.logout() - resp = self.client.get('/superset/queries/{}'.format(0)) + resp = self.client.get('/superset/queries/0') self.assertEquals(403, resp.status_code) def test_search_query_on_db_id(self): - self.login('admin') - # Test search queries on database Id - resp = self.get_resp('/superset/search_queries?database_id=1') - data = json.loads(resp) - self.assertEquals(3, len(data)) - db_ids = [data[k]['dbId'] for k in data] - self.assertEquals([1, 1, 1], db_ids) - - resp = self.get_resp('/superset/search_queries?database_id=-1') - data = json.loads(resp) - self.assertEquals(0, len(data)) - self.logout() + self.run_some_queries() + self.login('admin') + # Test search queries on database Id + resp = self.get_resp('/superset/search_queries?database_id=1') + data = json.loads(resp) + self.assertEquals(3, len(data)) + db_ids = [data[k]['dbId'] for k in data] + self.assertEquals([1, 1, 1], db_ids) + + resp = self.get_resp('/superset/search_queries?database_id=-1') + data = json.loads(resp) + self.assertEquals(0, len(data)) def test_search_query_on_user(self): - self.login('admin') - # Test search queries on user Id - user = appbuilder.sm.find_user('admin') - resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) - data = json.loads(resp) - self.assertEquals(2, len(data)) - user_ids = [data[k]['userId'] for k in data] - self.assertEquals([user.id, user.id], user_ids) - - user = appbuilder.sm.find_user('gamma') - resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertEquals(list(data.values())[0]['userId'] , user.id) - self.logout() + self.run_some_queries() + self.login('admin') + + # Test search queries on user Id + user = appbuilder.sm.find_user('admin') + data = self.get_json_resp( + '/superset/search_queries?user_id={}'.format(user.id)) + self.assertEquals(2, len(data)) + user_ids 
= {data[k]['userId'] for k in data} + self.assertEquals(set([user.id]), user_ids) + + user = appbuilder.sm.find_user('gamma_sqllab') + resp = self.get_resp('/superset/search_queries?user_id={}'.format(user.id)) + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertEquals(list(data.values())[0]['userId'] , user.id) def test_search_query_on_status(self): - self.login('admin') - # Test search queries on status - resp = self.get_resp('/superset/search_queries?status=success') - data = json.loads(resp) - self.assertEquals(2, len(data)) - states = [data[k]['state'] for k in data] - self.assertEquals(['success', 'success'], states) - - resp = self.get_resp('/superset/search_queries?status=failed') - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertEquals(list(data.values())[0]['state'], 'failed') - self.logout() + self.run_some_queries() + self.login('admin') + # Test search queries on status + resp = self.get_resp('/superset/search_queries?status=success') + data = json.loads(resp) + self.assertEquals(2, len(data)) + states = [data[k]['state'] for k in data] + self.assertEquals(['success', 'success'], states) + + resp = self.get_resp('/superset/search_queries?status=failed') + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertEquals(list(data.values())[0]['state'], 'failed') def test_search_query_on_text(self): - self.login('admin') - resp = self.get_resp('/superset/search_queries?search_text=permission') - data = json.loads(resp) - self.assertEquals(1, len(data)) - self.assertIn('permission', list(data.values())[0]['sql']) - self.logout() + self.run_some_queries() + self.login('admin') + resp = self.get_resp('/superset/search_queries?search_text=permission') + data = json.loads(resp) + self.assertEquals(1, len(data)) + self.assertIn('permission', list(data.values())[0]['sql']) def test_search_query_on_time(self): - self.login('admin') - first_query_time = db.session.query(models.Query).filter_by( - 
sql='SELECT * FROM ab_user').first().start_time - second_query_time = db.session.query(models.Query).filter_by( - sql='SELECT * FROM ab_permission').first().start_time - # Test search queries on time filter - from_time = 'from={}'.format(int(first_query_time)) - to_time = 'to={}'.format(int(second_query_time)) - params = [from_time, to_time] - resp = self.get_resp('/superset/search_queries?'+'&'.join(params)) - data = json.loads(resp) - self.assertEquals(2, len(data)) - for _, v in data.items(): - self.assertLess(int(first_query_time), v['startDttm']) - self.assertLess(v['startDttm'], int(second_query_time)) - self.logout() + self.run_some_queries() + self.login('admin') + first_query_time = ( + db.session.query(models.Query) + .filter_by(sql='SELECT * FROM ab_user').one() + ).start_time + second_query_time = ( + db.session.query(models.Query) + .filter_by(sql='SELECT * FROM ab_permission').one() + ).start_time + # Test search queries on time filter + from_time = 'from={}'.format(int(first_query_time)) + to_time = 'to={}'.format(int(second_query_time)) + params = [from_time, to_time] + resp = self.get_resp('/superset/search_queries?'+'&'.join(params)) + data = json.loads(resp) + self.assertEquals(2, len(data)) + for _, v in data.items(): + self.assertLess(int(first_query_time), v['startDttm']) + self.assertLess(v['startDttm'], int(second_query_time)) if __name__ == '__main__':