Skip to content

Commit

Permalink
feat: RLS for SQL Lab (#19999)
Browse files Browse the repository at this point in the history
* feat: RLS for SQL Lab

* Small fixes

* Pass username to security manager

* Update docstrings

* Add tests

* Remove type from docstring
  • Loading branch information
betodealmeida committed May 12, 2022
1 parent ded9122 commit f2881e5
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 66 deletions.
3 changes: 3 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
"UX_BETA": False,
"GENERIC_CHART_AXES": False,
"ALLOW_ADHOC_SUBQUERY": False,
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
# query, and might break queries and/or allow users to bypass RLS. Use with care!
"RLS_IN_SQLLAB": False,
}

# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
Expand Down
16 changes: 10 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,20 +1120,24 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:
col.name = f"{col.name}__"

def get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
self,
template_processor: BaseTemplateProcessor,
username: Optional[str] = None,
) -> List[TextClause]:
"""
Return the appropriate row level security filters for
this table and the current user.
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.
:param BaseTemplateProcessor template_processor: The template
processor to apply to the filters.
:param template_processor: The template processor to apply to the filters.
:param username: Optional username if there's no user in the Flask global
namespace.
:returns: A list of SQL clauses to be ANDed together.
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
for filter_ in security_manager.get_rls_filters(self, username):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
Expand Down
119 changes: 66 additions & 53 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,72 +1135,85 @@ def get_guest_rls_filters(
]
return []

def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
def get_rls_filters(
self,
table: "BaseDatasource",
username: Optional[str] = None,
) -> List[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
the passed table.
:param table: The table to check against
:param BaseDatasource table: The table to check against.
:param Optional[str] username: Optional username if there's no user in the Flask
global namespace.
:returns: A list of filters
"""
if hasattr(g, "user"):
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import (
RLSFilterRoles,
RLSFilterTables,
RowLevelSecurityFilter,
)
user = g.user
elif username:
user = self.find_user(username=username)
else:
return []

user_roles = [role.id for role in self.get_user_roles()]
regular_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.REGULAR
)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import (
RLSFilterRoles,
RLSFilterTables,
RowLevelSecurityFilter,
)

user_roles = [role.id for role in self.get_user_roles(user)]
regular_filter_roles = (
self.get_session()
.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.REGULAR
)
base_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.BASE
)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
base_filter_roles = (
self.get_session()
.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type == RowLevelSecurityFilterType.BASE
)
filter_tables = (
self.get_session.query(RLSFilterTables.c.rls_filter_id)
.filter(RLSFilterTables.c.table_id == table.id)
.subquery()
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
filter_tables = (
self.get_session()
.query(RLSFilterTables.c.rls_filter_id)
.filter(RLSFilterTables.c.table_id == table.id)
.subquery()
)
query = (
self.get_session()
.query(
RowLevelSecurityFilter.id,
RowLevelSecurityFilter.group_key,
RowLevelSecurityFilter.clause,
)
query = (
self.get_session.query(
RowLevelSecurityFilter.id,
RowLevelSecurityFilter.group_key,
RowLevelSecurityFilter.clause,
)
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
.filter(
or_(
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.REGULAR,
RowLevelSecurityFilter.id.in_(regular_filter_roles),
),
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.BASE,
RowLevelSecurityFilter.id.notin_(base_filter_roles),
),
)
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
.filter(
or_(
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.REGULAR,
RowLevelSecurityFilter.id.in_(regular_filter_roles),
),
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.BASE,
RowLevelSecurityFilter.id.notin_(base_filter_roles),
),
)
)
return query.all()
return []
)
return query.all()

def get_rls_ids(self, table: "BaseDatasource") -> List[int]:
"""
Expand Down
26 changes: 23 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
from flask_babel import gettext as __
from sqlalchemy.orm import Session

from superset import app, results_backend, results_backend_use_msgpack, security_manager
from superset import (
app,
is_feature_enabled,
results_backend,
results_backend_use_msgpack,
security_manager,
)
from superset.common.db_query_status import QueryStatus
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
Expand All @@ -41,7 +47,7 @@
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod, ParsedQuery
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.celery import session_scope
from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress
Expand Down Expand Up @@ -176,7 +182,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query, session)


def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
sql_statement: str,
query: Query,
user_name: Optional[str],
Expand All @@ -188,7 +194,21 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
"""Executes a single SQL statement"""
database: Database = query.database
db_engine_spec = database.db_engine_spec

parsed_query = ParsedQuery(sql_statement)
if is_feature_enabled("RLS_IN_SQLLAB"):
# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
str(
insert_rls(
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
username=user_name,
)
)
)

sql = parsed_query.stripped()
# This is a test to see if the query is being
# limited by either the dropdown or the sql.
Expand Down
6 changes: 4 additions & 2 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
username: Optional[str] = None,
) -> Optional[TokenList]:
"""
Given a table name, return any associated RLS predicates.
Expand Down Expand Up @@ -585,7 +586,7 @@ def get_rls_for_table(
template_processor = dataset.get_template_processor()
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
for filter_ in dataset.get_sqla_row_level_filters(template_processor, username)
)
if not predicate:
return None
Expand All @@ -600,6 +601,7 @@ def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
username: Optional[str] = None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
Expand All @@ -621,7 +623,7 @@ def insert_rls(
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
rls = get_rls_for_table(token, database_id, default_schema)
rls = get_rls_for_table(token, database_id, default_schema, username)
if rls:
state = InsertRLSState.FOUND_TABLE

Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ def sqllab_viz(self) -> FlaskResponse: # pylint: disable=no-self-use
@has_access
@expose("/extra_table_metadata/<int:database_id>/<table_name>/<schema>/")
@event_logger.log_this
def extra_table_metadata( # pylint: disable=no-self-use
def extra_table_metadata(
self, database_id: int, table_name: str, schema: str
) -> FlaskResponse:
logger.warning(
Expand Down

0 comments on commit f2881e5

Please sign in to comment.