Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed May 12, 2022
1 parent 12737e2 commit 394a814
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 4 deletions.
12 changes: 8 additions & 4 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,8 @@ def get_rls_filters(

user_roles = [role.id for role in self.get_user_roles(user)]
regular_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
self.get_session()
.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.REGULAR
Expand All @@ -1174,7 +1175,8 @@ def get_rls_filters(
.subquery()
)
base_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
self.get_session()
.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.BASE
Expand All @@ -1183,12 +1185,14 @@ def get_rls_filters(
.subquery()
)
filter_tables = (
self.get_session.query(RLSFilterTables.c.rls_filter_id)
self.get_session()
.query(RLSFilterTables.c.rls_filter_id)
.filter(RLSFilterTables.c.table_id == table.id)
.subquery()
)
query = (
self.get_session.query(
self.get_session()
.query(
RowLevelSecurityFilter.id,
RowLevelSecurityFilter.group_key,
RowLevelSecurityFilter.clause,
Expand Down
218 changes: 218 additions & 0 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals

import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session


def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
"""
Simple test for `execute_sql_statement`.
"""
from superset.sql_lab import execute_sql_statement

sql_statement = "SELECT 42 AS answer"

query = mocker.MagicMock()
query.limit = 1
query.select_as_cta_used = False
database = query.database
database.allow_dml = False
database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2"
db_engine_spec = database.db_engine_spec
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]

session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")

execute_sql_statement(
sql_statement,
query,
user_name=None,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,
)

database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", async_=True
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)


def test_execute_sql_statement_with_rls(
mocker: MockerFixture,
app_context: None,
) -> None:
"""
Test for `execute_sql_statement` when an RLS rule is in place.
"""
from superset.sql_lab import execute_sql_statement

sql_statement = "SELECT * FROM sales"

query = mocker.MagicMock()
query.limit = 100
query.select_as_cta_used = False
database = query.database
database.allow_dml = False
database.apply_limit_to_sql.return_value = (
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101"
)
db_engine_spec = database.db_engine_spec
db_engine_spec.is_select_query.return_value = True
db_engine_spec.fetch_data.return_value = [(42,)]

session = mocker.MagicMock()
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
mocker.patch(
"superset.sql_lab.insert_rls",
return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0],
)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

execute_sql_statement(
sql_statement,
query,
user_name=None,
session=session,
cursor=cursor,
log_params={},
apply_ctas=False,
)

database.apply_limit_to_sql.assert_called_with(
"SELECT * FROM sales WHERE organization_id=42",
101,
force=True,
)
db_engine_spec.execute.assert_called_with(
cursor,
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
async_=True,
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)


def test_sql_lab_insert_rls(
mocker: MockerFixture,
session: Session,
app_context: None,
) -> None:
"""
Integration test for `insert_rls`.
"""
from flask_appbuilder.security.sqla.models import Role, User

from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.security.manager import SupersetSecurityManager
from superset.sql_lab import execute_sql_statement
from superset.utils.core import RowLevelSecurityFilterType

engine = session.connection().engine
Query.metadata.create_all(engine) # pylint: disable=no-member

connection = engine.raw_connection()
connection.execute("CREATE TABLE t (c INTEGER)")
for i in range(10):
connection.execute("INSERT INTO t VALUES (?)", (i,))

cursor = connection.cursor()

query = Query(
sql="SELECT c FROM t",
client_id="abcde",
database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
schema=None,
limit=5,
select_as_cta_used=False,
)
session.add(query)
session.commit()

# first without RLS
superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
user_name="admin",
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
)
assert (
superset_result_set.to_pandas_df().to_markdown()
== """
| | c |
|---:|----:|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |""".strip()
)
assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

# now with RLS
admin = User(
first_name="Alice",
last_name="Doe",
email="adoe@example.org",
username="admin",
roles=[Role(name="Admin")],
)
rls = RowLevelSecurityFilter(
filter_type=RowLevelSecurityFilterType.REGULAR,
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
roles=[admin.roles[0]],
group_key=None,
clause="c > 5",
)
session.add(rls)
session.flush()
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

superset_result_set = execute_sql_statement(
sql_statement=query.sql,
query=query,
user_name="admin",
session=session,
cursor=cursor,
log_params=None,
apply_ctas=False,
)
assert (
superset_result_set.to_pandas_df().to_markdown()
== """
| | c |
|---:|----:|
| 0 | 6 |
| 1 | 7 |
| 2 | 8 |
| 3 | 9 |""".strip()
)
assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"

0 comments on commit 394a814

Please sign in to comment.