Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Revert "fix: datasource payload is incorrect (apache#15184)"
Browse files Browse the repository at this point in the history
This reverts commit 216e2b8.
  • Loading branch information
serenajiang committed Jun 23, 2021
1 parent b295c6a commit 658ab0f
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 117 deletions.
36 changes: 0 additions & 36 deletions superset/connectors/connector_registry.py
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import defaultdict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING

from flask_babel import _
Expand Down Expand Up @@ -100,41 +99,6 @@ def get_datasource_by_id( # pylint: disable=too-many-arguments
pass
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))

@classmethod
def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]:
from superset import security_manager

# collect datasources which the user has explicit permissions to
user_perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
user_datasources = set()
for datasource_class in ConnectorRegistry.sources.values():
user_datasources.update(
session.query(datasource_class)
.filter(
or_(
datasource_class.perm.in_(user_perms),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)

# group all datasources by database
all_datasources = cls.get_all_datasources(session)
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
set
)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)

# add datasources with implicit permission (eg, database access)
for database, datasources in datasources_by_database.items():
if security_manager.can_access_database(database):
user_datasources.update(datasources)

return list(user_datasources)

@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
Expand Down
2 changes: 1 addition & 1 deletion superset/views/chart/views.py
Expand Up @@ -65,7 +65,7 @@ def pre_delete(self, item: "SliceModelView") -> None:
def add(self) -> FlaskResponse:
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
for d in ConnectorRegistry.get_user_datasources(db.session)
for d in ConnectorRegistry.get_all_datasources(db.session)
]
payload = {
"datasources": sorted(
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Expand Up @@ -185,7 +185,7 @@ def datasources(self) -> FlaskResponse:
sorted(
[
datasource.short_data
for datasource in ConnectorRegistry.get_user_datasources(db.session)
for datasource in ConnectorRegistry.get_all_datasources(db.session)
if datasource.short_data.get("name")
],
key=lambda datasource: datasource["name"],
Expand Down
79 changes: 0 additions & 79 deletions tests/access_tests.py
Expand Up @@ -18,7 +18,6 @@
"""Unit tests for Superset"""
import json
import unittest
from collections import namedtuple
from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

Expand Down Expand Up @@ -627,83 +626,5 @@ def test_request_access(self):
session.commit()


class TestDatasources(SupersetTestCase):
def test_get_user_datasources_admin(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])

mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = []

with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = True

with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]

datasources = ConnectorRegistry.get_user_datasources(mock_session)

assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]

def test_get_user_datasources_gamma(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])

mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = []

with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = False

with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]

datasources = ConnectorRegistry.get_user_datasources(mock_session)

assert datasources == []

def test_get_user_datasources_gamma_with_schema(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])

mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]

with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = False

with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]

datasources = ConnectorRegistry.get_user_datasources(mock_session)

assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]


if __name__ == "__main__":
unittest.main()

0 comments on commit 658ab0f

Please sign in to comment.