From d916852557381b9bd0d8f42923de191ef45b0d87 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Wed, 30 Jun 2021 16:12:06 -0700 Subject: [PATCH] Revert "Revert "fix: datasource payload is incorrect (#15184)"" This reverts commit f230b19893687c7fed7bbb62f210f4ed576b93ef. --- superset/connectors/connector_registry.py | 36 +++++++++++ superset/views/chart/views.py | 2 +- superset/views/core.py | 2 +- tests/access_tests.py | 79 +++++++++++++++++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index 86081cd1b539..7de19b3484df 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from collections import defaultdict from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING from flask_babel import _ @@ -99,6 +100,41 @@ def get_datasource_by_id( # pylint: disable=too-many-arguments pass raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id)) + @classmethod + def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]: + from superset import security_manager + + # collect datasources which the user has explicit permissions to + user_perms = security_manager.user_view_menu_names("datasource_access") + schema_perms = security_manager.user_view_menu_names("schema_access") + user_datasources = set() + for datasource_class in ConnectorRegistry.sources.values(): + user_datasources.update( + session.query(datasource_class) + .filter( + or_( + datasource_class.perm.in_(user_perms), + datasource_class.schema_perm.in_(schema_perms), + ) + ) + .all() + ) + + # group all datasources by database + all_datasources = cls.get_all_datasources(session) + datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict( + set + ) + for datasource in all_datasources: + datasources_by_database[datasource.database].add(datasource) + + # add datasources with implicit permission (eg, database access) + for database, datasources in datasources_by_database.items(): + if security_manager.can_access_database(database): + user_datasources.update(datasources) + + return list(user_datasources) + @classmethod def get_datasource_by_name( # pylint: disable=too-many-arguments cls, diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py index e4fedac278d3..45dbd2308caa 100644 --- a/superset/views/chart/views.py +++ b/superset/views/chart/views.py @@ -65,7 +65,7 @@ def pre_delete(self, item: "SliceModelView") -> None: def add(self) -> FlaskResponse: datasources = [ {"value": str(d.id) + "__" + d.type, "label": repr(d)} - for d in ConnectorRegistry.get_all_datasources(db.session) + for d in ConnectorRegistry.get_user_datasources(db.session) ] payload = { "datasources": sorted( diff --git a/superset/views/core.py b/superset/views/core.py index 0e9dd9e2b45f..3be231475c20 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -186,7 +186,7 @@ def datasources(self) -> FlaskResponse: sorted( [ datasource.short_data - for datasource in ConnectorRegistry.get_all_datasources(db.session) + for datasource in ConnectorRegistry.get_user_datasources(db.session) if datasource.short_data.get("name") ], key=lambda datasource: datasource["name"], diff --git a/tests/access_tests.py b/tests/access_tests.py index d3cc55a1c9df..795ca98c30cf 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -18,6 +18,7 @@ """Unit tests for Superset""" import json import unittest +from collections import namedtuple from unittest import mock from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices @@ -626,5 +627,83 @@ def test_request_access(self): session.commit() +class TestDatasources(SupersetTestCase): + def test_get_user_datasources_admin(self): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + + mock_session = mock.MagicMock() + mock_session.query.return_value.filter.return_value.all.return_value = [] + + with mock.patch("superset.security_manager") as mock_security_manager: + mock_security_manager.can_access_database.return_value = True + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = ConnectorRegistry.get_user_datasources(mock_session) + + assert sorted(datasources) == [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + def test_get_user_datasources_gamma(self): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + + mock_session = mock.MagicMock() + mock_session.query.return_value.filter.return_value.all.return_value = [] + + with mock.patch("superset.security_manager") as mock_security_manager: + mock_security_manager.can_access_database.return_value = False + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = ConnectorRegistry.get_user_datasources(mock_session) + + assert datasources == [] + + def test_get_user_datasources_gamma_with_schema(self): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + + mock_session = mock.MagicMock() + mock_session.query.return_value.filter.return_value.all.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + ] + + with mock.patch("superset.security_manager") as mock_security_manager: + mock_security_manager.can_access_database.return_value = False + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = ConnectorRegistry.get_user_datasources(mock_session) + + assert sorted(datasources) == [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + ] + + if __name__ == "__main__": unittest.main()