Skip to content

Commit

Permalink
feat: datasource access to allow more granular access to tables on SQ…
Browse files Browse the repository at this point in the history
…L Lab (#18064)
  • Loading branch information
Victor Arbues committed Feb 9, 2022
1 parent fdbcbb5 commit 5ee070c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 15 deletions.
19 changes: 13 additions & 6 deletions superset/databases/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,28 @@

class DatabaseFilter(BaseFilter):
# TODO(bogdan): consider caching.
def schema_access_databases(self) -> Set[str]: # noqa pylint: disable=no-self-use

def can_access_databases( # noqa pylint: disable=no-self-use
self, view_menu_name: str,
) -> Set[str]:
return {
security_manager.unpack_schema_perm(vm)[0]
for vm in security_manager.user_view_menu_names("schema_access")
security_manager.unpack_database_and_schema(vm).database
for vm in security_manager.user_view_menu_names(view_menu_name)
}

def apply(self, query: Query, value: Any) -> Query:
if security_manager.can_access_all_databases():
return query
database_perms = security_manager.user_view_menu_names("database_access")
# TODO(bogdan): consider adding datasource access here as well.
schema_access_databases = self.schema_access_databases()
schema_access_databases = self.can_access_databases("schema_access")

datasource_access_databases = self.can_access_databases("datasource_access")

return query.filter(
or_(
self.model.perm.in_(database_perms),
self.model.database_name.in_(schema_access_databases),
self.model.database_name.in_(
[*schema_access_databases, *datasource_access_databases]
),
)
)
20 changes: 13 additions & 7 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
cast,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
Expand Down Expand Up @@ -88,6 +88,11 @@
logger = logging.getLogger(__name__)


class DatabaseAndSchema(NamedTuple):
database: str
schema: str


class SupersetSecurityListWidget(ListWidget): # pylint: disable=too-few-public-methods
"""
Redeclaring to avoid circular imports
Expand Down Expand Up @@ -263,13 +268,14 @@ def get_schema_perm( # pylint: disable=no-self-use

return None

def unpack_schema_perm( # pylint: disable=no-self-use
def unpack_database_and_schema( # pylint: disable=no-self-use
self, schema_permission: str
) -> Tuple[str, str]:
# [database_name].[schema_name]
) -> DatabaseAndSchema:
# [database_name].[schema|table]

schema_name = schema_permission.split(".")[1][1:-1]
database_name = schema_permission.split(".")[0][1:-1]
return database_name, schema_name
return DatabaseAndSchema(database_name, schema_name)

def can_access(self, permission_name: str, view_name: str) -> bool:
"""
Expand Down Expand Up @@ -558,7 +564,7 @@ def get_schemas_accessible_by_user(

# schema_access
accessible_schemas = {
self.unpack_schema_perm(s)[1]
self.unpack_database_and_schema(s).schema
for s in self.user_view_menu_names("schema_access")
if s.startswith(f"[{database}].")
}
Expand Down Expand Up @@ -608,7 +614,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
return [d for d in datasource_names if d in names]
return [d for d in datasource_names if d.table in names]

full_names = {d.full_name for d in user_datasources}
return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
Expand Down
59 changes: 59 additions & 0 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,65 @@ def test_get_superset_tables_not_allowed(self):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_get_superset_tables_allowed(self):
session = db.session
table_name = "energy_usage"
role_name = "dummy_role"
self.logout()
self.login(username="gamma")
gamma_user = security_manager.find_user(username="gamma")
security_manager.add_role(role_name)
dummy_role = security_manager.find_role(role_name)
gamma_user.roles.append(dummy_role)

tbl_id = self.table_ids.get(table_name)
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id).first()
table_perm = table.perm

security_manager.add_permission_role(
dummy_role,
security_manager.find_permission_view_menu("datasource_access", table_perm),
)

session.commit()

example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)

# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role(role_name))
session.commit()

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_get_superset_tables_not_allowed_with_out_permissions(self):
session = db.session
table_name = "energy_usage"
role_name = "dummy_role_no_table_access"
self.logout()
self.login(username="gamma")
gamma_user = security_manager.find_user(username="gamma")
security_manager.add_role(role_name)
dummy_role = security_manager.find_role(role_name)
gamma_user.roles.append(dummy_role)

session.commit()

example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)

# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role(role_name))
session.commit()

def test_get_superset_tables_substr(self):
example_db = superset.utils.database.get_example_database()
if example_db.backend in {"presto", "hive"}:
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def test_get_dataset_related_database_gamma(self):
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["count"] == 0
assert response["result"] == []

assert response["count"] == 1
main_db = get_main_database()
assert filter(lambda x: x.text == main_db, response["result"]) != []

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_get_dataset_item(self):
Expand Down

0 comments on commit 5ee070c

Please sign in to comment.