diff --git a/superset/config.py b/superset/config.py index ee15e32b3b44..b507d67eae47 100644 --- a/superset/config.py +++ b/superset/config.py @@ -412,6 +412,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: "UX_BETA": False, "GENERIC_CHART_AXES": False, "ALLOW_ADHOC_SUBQUERY": False, + # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the + # query, and might break queries and/or allow users to bypass RLS. Use with care! + "RLS_IN_SQLLAB": False, } # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 94190f678b48..489c483baf62 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1120,20 +1120,24 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool: col.name = f"{col.name}__" def get_sqla_row_level_filters( - self, template_processor: BaseTemplateProcessor + self, + template_processor: BaseTemplateProcessor, + username: Optional[str] = None, ) -> List[TextClause]: """ - Return the appropriate row level security filters for - this table and the current user. + Return the appropriate row level security filters for this table and the + current user. A custom username can be passed when the user is not present in the + Flask global namespace. - :param BaseTemplateProcessor template_processor: The template - processor to apply to the filters. + :param template_processor: The template processor to apply to the filters. + :param username: Optional username if there's no user in the Flask global + namespace. :returns: A list of SQL clauses to be ANDed together. 
""" all_filters: List[TextClause] = [] filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) try: - for filter_ in security_manager.get_rls_filters(self): + for filter_ in security_manager.get_rls_filters(self, username): clause = self.text( f"({template_processor.process_template(filter_.clause)})" ) diff --git a/superset/security/manager.py b/superset/security/manager.py index c422c2e8bbe3..44c53329ade6 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1135,72 +1135,85 @@ def get_guest_rls_filters( ] return [] - def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: + def get_rls_filters( + self, + table: "BaseDatasource", + username: Optional[str] = None, + ) -> List[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and the passed table. - :param table: The table to check against + :param BaseDatasource table: The table to check against. + :param Optional[str] username: Optional username if there's no user in the Flask + global namespace. 
:returns: A list of filters """ if hasattr(g, "user"): - # pylint: disable=import-outside-toplevel - from superset.connectors.sqla.models import ( - RLSFilterRoles, - RLSFilterTables, - RowLevelSecurityFilter, - ) + user = g.user + elif username: + user = self.find_user(username=username) + else: + return [] - user_roles = [role.id for role in self.get_user_roles()] - regular_filter_roles = ( - self.get_session.query(RLSFilterRoles.c.rls_filter_id) - .join(RowLevelSecurityFilter) - .filter( - RowLevelSecurityFilter.filter_type - == RowLevelSecurityFilterType.REGULAR - ) - .filter(RLSFilterRoles.c.role_id.in_(user_roles)) - .subquery() + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import ( + RLSFilterRoles, + RLSFilterTables, + RowLevelSecurityFilter, + ) + + user_roles = [role.id for role in self.get_user_roles(user)] + regular_filter_roles = ( + self.get_session() + .query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter( + RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.REGULAR ) - base_filter_roles = ( - self.get_session.query(RLSFilterRoles.c.rls_filter_id) - .join(RowLevelSecurityFilter) - .filter( - RowLevelSecurityFilter.filter_type - == RowLevelSecurityFilterType.BASE - ) - .filter(RLSFilterRoles.c.role_id.in_(user_roles)) - .subquery() + .filter(RLSFilterRoles.c.role_id.in_(user_roles)) + .subquery() + ) + base_filter_roles = ( + self.get_session() + .query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter( + RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.BASE ) - filter_tables = ( - self.get_session.query(RLSFilterTables.c.rls_filter_id) - .filter(RLSFilterTables.c.table_id == table.id) - .subquery() + .filter(RLSFilterRoles.c.role_id.in_(user_roles)) + .subquery() + ) + filter_tables = ( + self.get_session() + .query(RLSFilterTables.c.rls_filter_id) + .filter(RLSFilterTables.c.table_id == table.id) + .subquery() + ) + query = ( + 
self.get_session() + .query( + RowLevelSecurityFilter.id, + RowLevelSecurityFilter.group_key, + RowLevelSecurityFilter.clause, ) - query = ( - self.get_session.query( - RowLevelSecurityFilter.id, - RowLevelSecurityFilter.group_key, - RowLevelSecurityFilter.clause, - ) - .filter(RowLevelSecurityFilter.id.in_(filter_tables)) - .filter( - or_( - and_( - RowLevelSecurityFilter.filter_type - == RowLevelSecurityFilterType.REGULAR, - RowLevelSecurityFilter.id.in_(regular_filter_roles), - ), - and_( - RowLevelSecurityFilter.filter_type - == RowLevelSecurityFilterType.BASE, - RowLevelSecurityFilter.id.notin_(base_filter_roles), - ), - ) + .filter(RowLevelSecurityFilter.id.in_(filter_tables)) + .filter( + or_( + and_( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.REGULAR, + RowLevelSecurityFilter.id.in_(regular_filter_roles), + ), + and_( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.BASE, + RowLevelSecurityFilter.id.notin_(base_filter_roles), + ), ) ) - return query.all() - return [] + ) + return query.all() def get_rls_ids(self, table: "BaseDatasource") -> List[int]: """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 49de3df95752..3f43c0faae66 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -31,7 +31,13 @@ from flask_babel import gettext as __ from sqlalchemy.orm import Session -from superset import app, results_backend, results_backend_use_msgpack, security_manager +from superset import ( + app, + is_feature_enabled, + results_backend, + results_backend_use_msgpack, + security_manager, +) from superset.common.db_query_status import QueryStatus from superset.dataframe import df_to_records from superset.db_engine_specs import BaseEngineSpec @@ -41,7 +47,7 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql_parse import CtasMethod, ParsedQuery +from superset.sql_parse import CtasMethod, insert_rls, 
ParsedQuery from superset.sqllab.limiting_factor import LimitingFactor from superset.utils.celery import session_scope from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress @@ -176,7 +182,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals +def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements sql_statement: str, query: Query, user_name: Optional[str], @@ -188,7 +194,21 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals """Executes a single SQL statement""" database: Database = query.database db_engine_spec = database.db_engine_spec + parsed_query = ParsedQuery(sql_statement) + if is_feature_enabled("RLS_IN_SQLLAB"): + # Insert any applicable RLS predicates + parsed_query = ParsedQuery( + str( + insert_rls( + parsed_query._parsed[0], # pylint: disable=protected-access + database.id, + query.schema, + username=user_name, + ) + ) + ) + sql = parsed_query.stripped() # This is a test to see if the query is being # limited by either the dropdown or the sql. diff --git a/superset/sql_parse.py b/superset/sql_parse.py index d377986f5657..b585810f785a 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -553,6 +553,7 @@ def get_rls_for_table( candidate: Token, database_id: int, default_schema: Optional[str], + username: Optional[str] = None, ) -> Optional[TokenList]: """ Given a table name, return any associated RLS predicates. 
@@ -585,7 +586,7 @@ def get_rls_for_table( template_processor = dataset.get_template_processor() predicate = " AND ".join( str(filter_) - for filter_ in dataset.get_sqla_row_level_filters(template_processor) + for filter_ in dataset.get_sqla_row_level_filters(template_processor, username) ) if not predicate: return None @@ -600,6 +601,7 @@ def insert_rls( token_list: TokenList, database_id: int, default_schema: Optional[str], + username: Optional[str] = None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -621,7 +623,7 @@ def insert_rls( elif state == InsertRLSState.SEEN_SOURCE and ( isinstance(token, Identifier) or token.ttype == Keyword ): - rls = get_rls_for_table(token, database_id, default_schema) + rls = get_rls_for_table(token, database_id, default_schema, username) if rls: state = InsertRLSState.FOUND_TABLE diff --git a/superset/views/core.py b/superset/views/core.py index 9ced8b485ff3..660c527c0a88 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2101,7 +2101,7 @@ def sqllab_viz(self) -> FlaskResponse: # pylint: disable=no-self-use @has_access @expose("/extra_table_metadata/<int:database_id>/<table_name>/<schema>/") @event_logger.log_this - def extra_table_metadata( # pylint: disable=no-self-use + def extra_table_metadata( self, database_id: int, table_name: str, schema: str ) -> FlaskResponse: logger.warning( diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py new file mode 100644 index 000000000000..dcc2ded750fb --- /dev/null +++ b/tests/unit_tests/sql_lab_test.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals + +import sqlparse +from pytest_mock import MockerFixture +from sqlalchemy.orm.session import Session + + +def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: + """ + Simple test for `execute_sql_statement`. + """ + from superset.sql_lab import execute_sql_statement + + sql_statement = "SELECT 42 AS answer" + + query = mocker.MagicMock() + query.limit = 1 + query.select_as_cta_used = False + database = query.database + database.allow_dml = False + database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2" + db_engine_spec = database.db_engine_spec + db_engine_spec.is_select_query.return_value = True + db_engine_spec.fetch_data.return_value = [(42,)] + + session = mocker.MagicMock() + cursor = mocker.MagicMock() + SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") + + execute_sql_statement( + sql_statement, + query, + user_name=None, + session=session, + cursor=cursor, + log_params={}, + apply_ctas=False, + ) + + database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True) + db_engine_spec.execute.assert_called_with( + cursor, "SELECT 42 AS answer LIMIT 2", async_=True + ) + SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) + + +def test_execute_sql_statement_with_rls( + mocker: MockerFixture, + app_context: None, +) -> None: + """ + Test for `execute_sql_statement` when an RLS rule is in place. 
+ """ + from superset.sql_lab import execute_sql_statement + + sql_statement = "SELECT * FROM sales" + + query = mocker.MagicMock() + query.limit = 100 + query.select_as_cta_used = False + database = query.database + database.allow_dml = False + database.apply_limit_to_sql.return_value = ( + "SELECT * FROM sales WHERE organization_id=42 LIMIT 101" + ) + db_engine_spec = database.db_engine_spec + db_engine_spec.is_select_query.return_value = True + db_engine_spec.fetch_data.return_value = [(42,)] + + session = mocker.MagicMock() + cursor = mocker.MagicMock() + SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") + mocker.patch( + "superset.sql_lab.insert_rls", + return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0], + ) + mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) + + execute_sql_statement( + sql_statement, + query, + user_name=None, + session=session, + cursor=cursor, + log_params={}, + apply_ctas=False, + ) + + database.apply_limit_to_sql.assert_called_with( + "SELECT * FROM sales WHERE organization_id=42", + 101, + force=True, + ) + db_engine_spec.execute.assert_called_with( + cursor, + "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", + async_=True, + ) + SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) + + +def test_sql_lab_insert_rls( + mocker: MockerFixture, + session: Session, + app_context: None, +) -> None: + """ + Integration test for `insert_rls`. 
+ """ + from flask_appbuilder.security.sqla.models import Role, User + + from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable + from superset.models.core import Database + from superset.models.sql_lab import Query + from superset.security.manager import SupersetSecurityManager + from superset.sql_lab import execute_sql_statement + from superset.utils.core import RowLevelSecurityFilterType + + engine = session.connection().engine + Query.metadata.create_all(engine) # pylint: disable=no-member + + connection = engine.raw_connection() + connection.execute("CREATE TABLE t (c INTEGER)") + for i in range(10): + connection.execute("INSERT INTO t VALUES (?)", (i,)) + + cursor = connection.cursor() + + query = Query( + sql="SELECT c FROM t", + client_id="abcde", + database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"), + schema=None, + limit=5, + select_as_cta_used=False, + ) + session.add(query) + session.commit() + + # first without RLS + superset_result_set = execute_sql_statement( + sql_statement=query.sql, + query=query, + user_name="admin", + session=session, + cursor=cursor, + log_params=None, + apply_ctas=False, + ) + assert ( + superset_result_set.to_pandas_df().to_markdown() + == """ +| | c | +|---:|----:| +| 0 | 0 | +| 1 | 1 | +| 2 | 2 | +| 3 | 3 | +| 4 | 4 |""".strip() + ) + assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" + + # now with RLS + admin = User( + first_name="Alice", + last_name="Doe", + email="adoe@example.org", + username="admin", + roles=[Role(name="Admin")], + ) + rls = RowLevelSecurityFilter( + filter_type=RowLevelSecurityFilterType.REGULAR, + tables=[SqlaTable(database_id=1, schema=None, table_name="t")], + roles=[admin.roles[0]], + group_key=None, + clause="c > 5", + ) + session.add(rls) + session.flush() + mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) + mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) + + superset_result_set = 
execute_sql_statement( + sql_statement=query.sql, + query=query, + user_name="admin", + session=session, + cursor=cursor, + log_params=None, + apply_ctas=False, + ) + assert ( + superset_result_set.to_pandas_df().to_markdown() + == """ +| | c | +|---:|----:| +| 0 | 6 | +| 1 | 7 | +| 2 | 8 | +| 3 | 9 |""".strip() + ) + assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6" diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index d9c5d64c5950..1d2c788496af 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1406,7 +1406,10 @@ def test_insert_rls( # pylint: disable=unused-argument def get_rls_for_table( - candidate: Token, database_id: int, default_schema: str + candidate: Token, + database_id: int, + default_schema: str, + username: Optional[str] = None, ) -> Optional[TokenList]: """ Return the RLS ``condition`` if ``candidate`` matches ``table``.