diff --git a/superset/config.py b/superset/config.py index 6760a0af7290..c710b29657c3 100644 --- a/superset/config.py +++ b/superset/config.py @@ -44,6 +44,7 @@ Tuple, Type, TYPE_CHECKING, + TypedDict, Union, ) @@ -54,6 +55,7 @@ from flask import Blueprint from flask_appbuilder.security.manager import AUTH_DB from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module +from sqlalchemy.orm.query import Query from superset.advanced_data_type.plugins.internet_address import internet_address from superset.advanced_data_type.plugins.internet_port import internet_port @@ -1502,6 +1504,32 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument }, } + +# Extra related query filters make it possible to limit which objects are shown +# in the UI. For examples, to only show "admin" or users starting with the letter "b" in +# the "Owners" dropdowns, you could add the following in your config: +# def user_filter(query: Query, *args, *kwargs): +# from superset import security_manager +# +# user_model = security_manager.user_model +# filters = [ +# user_model.username == "admin", +# user_model.username.ilike("b%"), +# ] +# return query.filter(or_(*filters)) +# +# EXTRA_RELATED_QUERY_FILTERS = {"user": user_filter} +# +# Similarly, to restrict the roles in the "Roles" dropdown you can provide a custom +# filter callback for the "role" key. +class ExtraRelatedQueryFilters(TypedDict, total=False): + role: Callable[[Query], Query] + user: Callable[[Query], Query] + + +EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {} + + # ------------------------------------------------------------------- # * WARNING: STOP EDITING HERE * # ------------------------------------------------------------------- diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 615855123d26..64ea637c663d 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -95,7 +95,11 @@ requires_json, statsd_metrics, ) -from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners +from superset.views.filters import ( + BaseFilterRelatedRoles, + BaseFilterRelatedUsers, + FilterRelatedOwners, +) logger = logging.getLogger(__name__) @@ -244,7 +248,9 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: base_related_field_filters = { "owners": [["id", BaseFilterRelatedUsers, lambda: []]], "created_by": [["id", BaseFilterRelatedUsers, lambda: []]], + "roles": [["id", BaseFilterRelatedRoles, lambda: []]], } + related_field_filters = { "owners": RelatedFieldFilter("first_name", FilterRelatedOwners), "roles": RelatedFieldFilter("name", FilterRelatedRoles), diff --git a/superset/views/filters.py b/superset/views/filters.py index 9450a830332d..625566b98828 100644 --- a/superset/views/filters.py +++ b/superset/views/filters.py @@ -72,11 +72,35 @@ class BaseFilterRelatedUsers(BaseFilter): # pylint: disable=too-few-public-meth arg_name = "username" def apply(self, query: Query, value: Optional[Any]) -> Query: - user_model = security_manager.user_model + if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get( + "user", + ): + query = extra_filters(query) + exclude_users = ( security_manager.get_exclude_users_from_lists() if current_app.config["EXCLUDE_USERS_FROM_LISTS"] is None else current_app.config["EXCLUDE_USERS_FROM_LISTS"] ) - query_ = query.filter(and_(user_model.username.not_in(exclude_users))) - return query_ + if exclude_users: + user_model = security_manager.user_model + return query.filter(and_(user_model.username.not_in(exclude_users))) + + return query + + +class BaseFilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-methods + """ + Filter to apply on related roles. + """ + + name = lazy_gettext("role") + arg_name = "role" + + def apply(self, query: Query, value: Optional[Any]) -> Query: + if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get( + "role", + ): + return extra_filters(query) + + return query diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index 66544f1447e0..478fee0a0dca 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -219,6 +219,26 @@ def test_get_related_owners(self): for expected_user in expected_users: assert expected_user in response_users + def test_get_related_owners_with_extra_filters(self): + """ + API: Test get related owners with extra related query filters + """ + self.login(username="admin") + + def _base_filter(query): + return query.filter_by(username="alpha") + + with patch.dict( + "superset.views.filters.current_app.config", + {"EXTRA_RELATED_QUERY_FILTERS": {"user": _base_filter}}, + ): + uri = f"api/v1/{self.resource_name}/related/owners" + rv = self.client.get(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + response_users = [result["text"] for result in response["result"]] + assert response_users == ["alpha user"] + def test_get_related_owners_paginated(self): """ API: Test get related owners with pagination diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 1ca80aae38df..10ca16f4e713 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -1828,6 +1828,26 @@ def test_get_filter_related_roles(self): response_roles = [result["text"] for result in response["result"]] assert "Alpha" in response_roles + def test_get_all_related_roles_with_with_extra_filters(self): + """ + API: Test get filter related roles with extra related query filters + """ + self.login(username="admin") + + def _base_filter(query): + return query.filter_by(name="Alpha") + + with patch.dict( + "superset.views.filters.current_app.config", + {"EXTRA_RELATED_QUERY_FILTERS": {"role": _base_filter}}, + ): + uri = f"api/v1/dashboard/related/roles" + rv = self.client.get(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + response_roles = [result["text"] for result in response["result"]] + assert response_roles == ["Alpha"] + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_embedded_dashboards(self): self.login(username="admin")