From 5ee070c40228d6abbb30e4a8f7888886cf35d7f1 Mon Sep 17 00:00:00 2001 From: Victor Arbues Date: Wed, 9 Feb 2022 14:05:25 +0000 Subject: [PATCH] feat: datasource access to allow more granular access to tables on SQL Lab (#18064) --- superset/databases/filters.py | 19 ++++-- superset/security/manager.py | 20 ++++--- tests/integration_tests/core_tests.py | 59 +++++++++++++++++++ tests/integration_tests/datasets/api_tests.py | 6 +- 4 files changed, 89 insertions(+), 15 deletions(-) diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 6fa9339c07f2..bee7d2c7b213 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -25,21 +25,28 @@ class DatabaseFilter(BaseFilter): # TODO(bogdan): consider caching. - def schema_access_databases(self) -> Set[str]: # noqa pylint: disable=no-self-use + + def can_access_databases( # noqa pylint: disable=no-self-use + self, view_menu_name: str, + ) -> Set[str]: return { - security_manager.unpack_schema_perm(vm)[0] - for vm in security_manager.user_view_menu_names("schema_access") + security_manager.unpack_database_and_schema(vm).database + for vm in security_manager.user_view_menu_names(view_menu_name) } def apply(self, query: Query, value: Any) -> Query: if security_manager.can_access_all_databases(): return query database_perms = security_manager.user_view_menu_names("database_access") - # TODO(bogdan): consider adding datasource access here as well. - schema_access_databases = self.schema_access_databases() + schema_access_databases = self.can_access_databases("schema_access") + + datasource_access_databases = self.can_access_databases("datasource_access") + return query.filter( or_( self.model.perm.in_(database_perms), - self.model.database_name.in_(schema_access_databases), + self.model.database_name.in_( + [*schema_access_databases, *datasource_access_databases] + ), ) ) diff --git a/superset/security/manager.py b/superset/security/manager.py index d9206dfe4d37..0bed4476e526 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -26,9 +26,9 @@ cast, Dict, List, + NamedTuple, Optional, Set, - Tuple, TYPE_CHECKING, Union, ) @@ -88,6 +88,11 @@ logger = logging.getLogger(__name__) +class DatabaseAndSchema(NamedTuple): + database: str + schema: str + + class SupersetSecurityListWidget(ListWidget): # pylint: disable=too-few-public-methods """ Redeclaring to avoid circular imports @@ -263,13 +268,14 @@ def get_schema_perm( # pylint: disable=no-self-use return None - def unpack_schema_perm( # pylint: disable=no-self-use + def unpack_database_and_schema( # pylint: disable=no-self-use self, schema_permission: str - ) -> Tuple[str, str]: - # [database_name].[schema_name] + ) -> DatabaseAndSchema: + # [database_name].[schema|table] + schema_name = schema_permission.split(".")[1][1:-1] database_name = schema_permission.split(".")[0][1:-1] - return database_name, schema_name + return DatabaseAndSchema(database_name, schema_name) def can_access(self, permission_name: str, view_name: str) -> bool: """ @@ -558,7 +564,7 @@ def get_schemas_accessible_by_user( # schema_access accessible_schemas = { - self.unpack_schema_perm(s)[1] + self.unpack_database_and_schema(s).schema for s in self.user_view_menu_names("schema_access") if s.startswith(f"[{database}].") } @@ -608,7 +614,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name ) if schema: names = {d.table_name for d in user_datasources if d.schema == schema} - return [d for d in datasource_names if d in names] + return [d for d in datasource_names if d.table in names] full_names = {d.full_name for d in user_datasources} return [d for d in datasource_names if f"[{database}].[{d}]" in full_names] diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 1c4682ad9a9b..43288e08ae64 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -163,6 +163,65 @@ def test_get_superset_tables_not_allowed(self): rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_get_superset_tables_allowed(self): + session = db.session + table_name = "energy_usage" + role_name = "dummy_role" + self.logout() + self.login(username="gamma") + gamma_user = security_manager.find_user(username="gamma") + security_manager.add_role(role_name) + dummy_role = security_manager.find_role(role_name) + gamma_user.roles.append(dummy_role) + + tbl_id = self.table_ids.get(table_name) + table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id).first() + table_perm = table.perm + + security_manager.add_permission_role( + dummy_role, + security_manager.find_permission_view_menu("datasource_access", table_perm), + ) + + session.commit() + + example_db = utils.get_example_database() + schema_name = self.default_schema_backend_map[example_db.backend] + uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + + # cleanup + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.remove(security_manager.find_role(role_name)) + session.commit() + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_get_superset_tables_not_allowed_with_out_permissions(self): + session = db.session + table_name = "energy_usage" + role_name = "dummy_role_no_table_access" + self.logout() + self.login(username="gamma") + gamma_user = security_manager.find_user(username="gamma") + security_manager.add_role(role_name) + dummy_role = security_manager.find_role(role_name) + gamma_user.roles.append(dummy_role) + + session.commit() + + example_db = utils.get_example_database() + schema_name = self.default_schema_backend_map[example_db.backend] + uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + # cleanup + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.remove(security_manager.find_role(role_name)) + session.commit() + def test_get_superset_tables_substr(self): example_db = superset.utils.database.get_example_database() if example_db.backend in {"presto", "hive"}: diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 0e7606c45ed5..aaf76338ef01 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -219,8 +219,10 @@ def test_get_dataset_related_database_gamma(self): rv = self.client.get(uri) assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - assert response["count"] == 0 - assert response["result"] == [] + + assert response["count"] == 1 + main_db = get_main_database() + assert filter(lambda x: x.text == main_db, response["result"]) != [] @pytest.mark.usefixtures("load_energy_table_with_slice") def test_get_dataset_item(self):