Skip to content

Commit

Permalink
feat(rbac): add customizable related filters (#22526)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Jan 5, 2023
1 parent b352947 commit 037deb9
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 4 deletions.
28 changes: 28 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Tuple,
Type,
TYPE_CHECKING,
TypedDict,
Union,
)

Expand All @@ -54,6 +55,7 @@
from flask import Blueprint
from flask_appbuilder.security.manager import AUTH_DB
from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module
from sqlalchemy.orm.query import Query

from superset.advanced_data_type.plugins.internet_address import internet_address
from superset.advanced_data_type.plugins.internet_port import internet_port
Expand Down Expand Up @@ -1502,6 +1504,32 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
},
}


# Extra related query filters make it possible to limit which objects are shown
# in the UI. For examples, to only show "admin" or users starting with the letter "b" in
# the "Owners" dropdowns, you could add the following in your config:
# def user_filter(query: Query, *args, *kwargs):
# from superset import security_manager
#
# user_model = security_manager.user_model
# filters = [
# user_model.username == "admin",
# user_model.username.ilike("b%"),
# ]
# return query.filter(or_(*filters))
#
# EXTRA_RELATED_QUERY_FILTERS = {"user": user_filter}
#
# Similarly, to restrict the roles in the "Roles" dropdown you can provide a custom
# filter callback for the "role" key.
class ExtraRelatedQueryFilters(TypedDict, total=False):
role: Callable[[Query], Query]
user: Callable[[Query], Query]


EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {}


# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
# -------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion superset/dashboards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@
requires_json,
statsd_metrics,
)
from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners
from superset.views.filters import (
BaseFilterRelatedRoles,
BaseFilterRelatedUsers,
FilterRelatedOwners,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -244,7 +248,9 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]:
base_related_field_filters = {
"owners": [["id", BaseFilterRelatedUsers, lambda: []]],
"created_by": [["id", BaseFilterRelatedUsers, lambda: []]],
"roles": [["id", BaseFilterRelatedRoles, lambda: []]],
}

related_field_filters = {
"owners": RelatedFieldFilter("first_name", FilterRelatedOwners),
"roles": RelatedFieldFilter("name", FilterRelatedRoles),
Expand Down
30 changes: 27 additions & 3 deletions superset/views/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,35 @@ class BaseFilterRelatedUsers(BaseFilter): # pylint: disable=too-few-public-meth
arg_name = "username"

def apply(self, query: Query, value: Optional[Any]) -> Query:
user_model = security_manager.user_model
if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get(
"user",
):
query = extra_filters(query)

exclude_users = (
security_manager.get_exclude_users_from_lists()
if current_app.config["EXCLUDE_USERS_FROM_LISTS"] is None
else current_app.config["EXCLUDE_USERS_FROM_LISTS"]
)
query_ = query.filter(and_(user_model.username.not_in(exclude_users)))
return query_
if exclude_users:
user_model = security_manager.user_model
return query.filter(and_(user_model.username.not_in(exclude_users)))

return query


class BaseFilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-methods
"""
Filter to apply on related roles.
"""

name = lazy_gettext("role")
arg_name = "role"

def apply(self, query: Query, value: Optional[Any]) -> Query:
if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get(
"role",
):
return extra_filters(query)

return query
20 changes: 20 additions & 0 deletions tests/integration_tests/base_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,26 @@ def test_get_related_owners(self):
for expected_user in expected_users:
assert expected_user in response_users

def test_get_related_owners_with_extra_filters(self):
"""
API: Test get related owners with extra related query filters
"""
self.login(username="admin")

def _base_filter(query):
return query.filter_by(username="alpha")

with patch.dict(
"superset.views.filters.current_app.config",
{"EXTRA_RELATED_QUERY_FILTERS": {"user": _base_filter}},
):
uri = f"api/v1/{self.resource_name}/related/owners"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
response_users = [result["text"] for result in response["result"]]
assert response_users == ["alpha user"]

def test_get_related_owners_paginated(self):
"""
API: Test get related owners with pagination
Expand Down
20 changes: 20 additions & 0 deletions tests/integration_tests/dashboards/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,26 @@ def test_get_filter_related_roles(self):
response_roles = [result["text"] for result in response["result"]]
assert "Alpha" in response_roles

def test_get_all_related_roles_with_with_extra_filters(self):
"""
API: Test get filter related roles with extra related query filters
"""
self.login(username="admin")

def _base_filter(query):
return query.filter_by(name="Alpha")

with patch.dict(
"superset.views.filters.current_app.config",
{"EXTRA_RELATED_QUERY_FILTERS": {"role": _base_filter}},
):
uri = f"api/v1/dashboard/related/roles"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
response_roles = [result["text"] for result in response["result"]]
assert response_roles == ["Alpha"]

@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_embedded_dashboards(self):
self.login(username="admin")
Expand Down

0 comments on commit 037deb9

Please sign in to comment.