Skip to content

Commit

Permalink
feat: add name, description and non null tables to RLS (#20432)
Browse files Browse the repository at this point in the history
* feat: add name, description and non null tables to RLS

* add validation

* add and fix tests

* fix sqlite migration

* improve default value for name
  • Loading branch information
dpgaspar committed Jun 20, 2022
1 parent 8b0bee5 commit 60eb109
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 7 deletions.
3 changes: 2 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):

__tablename__ = "row_level_security_filters"
id = Column(Integer, primary_key=True)
name = Column(String(255), unique=True, nullable=False)
description = Column(Text)
filter_type = Column(
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
)
Expand All @@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
tables = relationship(
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
)

clause = Column(Text, nullable=False)
45 changes: 40 additions & 5 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flask_appbuilder.security.decorators import has_access
from flask_babel import lazy_gettext as _
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import Regexp
from wtforms.validators import DataRequired, Regexp

from superset import app, db
from superset.connectors.base.views import DatasourceModelView
Expand All @@ -47,6 +47,19 @@
logger = logging.getLogger(__name__)


class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods
"""
Select required flag on the input field will not work well on Chrome
Console error:
An invalid form control with name='tables' is not focusable.
This makes a simple override to the DataRequired to be used specifically with
select fields
"""

field_flags = ()


class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
datamodel = SQLAInterface(models.TableColumn)
# TODO TODO, review need for this on related_views
Expand Down Expand Up @@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
edit_title = _("Edit Row level security filter")

list_columns = [
"name",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
"creator",
"modified",
]
order_columns = ["filter_type", "group_key", "clause", "modified"]
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
order_columns = ["name", "filter_type", "clause", "modified"]
edit_columns = [
"name",
"description",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
]
show_columns = edit_columns
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
search_columns = (
"name",
"description",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
)
add_columns = edit_columns
base_order = ("changed_on", "desc")
description_columns = {
"name": _("Choose a unique name"),
"description": _("Optionally add a detailed description"),
"filter_type": _(
"Regular filters add where clauses to queries if a user belongs to a "
"role referenced in the filter. Base filters apply filters to all queries "
Expand Down Expand Up @@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
),
}
label_columns = {
"name": _("Name"),
"description": _("Description"),
"tables": _("Tables"),
"roles": _("Roles"),
"clause": _("Clause"),
"creator": _("Creator"),
"modified": _("Modified"),
}
validators_columns = {"tables": [SelectDataRequired()]}

if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
edit_form_query_rel_fields = add_form_query_rel_fields
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_unique_name_desc_rls
Revision ID: f3afaf1f11f0
Revises: e786798587de
Create Date: 2022-06-19 16:17:23.318618
"""

# revision identifiers, used by Alembic.
revision = "f3afaf1f11f0"
down_revision = "e786798587de"

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class RowLevelSecurityFilter(Base):
__tablename__ = "row_level_security_filters"
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(255), unique=True, nullable=False)


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
bind = op.get_bind()
session = Session(bind=bind)

op.add_column(
"row_level_security_filters", sa.Column("name", sa.String(length=255))
)
op.add_column(
"row_level_security_filters", sa.Column("description", sa.Text(), nullable=True)
)

# Set initial default names make sure we can have unique non null values
all_rls = session.query(RowLevelSecurityFilter).all()
for rls in all_rls:
rls.name = f"rls-{rls.id}"
session.commit()

# Now it's safe so set non-null and unique
# add unique constraint
with op.batch_alter_table("row_level_security_filters") as batch_op:
# batch mode is required for sqlite
batch_op.alter_column(
"name",
existing_type=sa.String(255),
nullable=False,
)
batch_op.create_unique_constraint("uq_rls_name", ["name"])
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique")
op.drop_column("row_level_security_filters", "description")
op.drop_column("row_level_security_filters", "name")
# ### end Alembic commands ###
93 changes: 92 additions & 1 deletion tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import (
GuestTokenRlsRule,
GuestTokenResourceType,
GuestUser,
)
Expand Down Expand Up @@ -82,6 +81,7 @@ def setUp(self):

# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
self.rls_entry1 = RowLevelSecurityFilter()
self.rls_entry1.name = "rls_entry1"
self.rls_entry1.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
Expand All @@ -96,6 +96,7 @@ def setUp(self):

# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
self.rls_entry2 = RowLevelSecurityFilter()
self.rls_entry2.name = "rls_entry2"
self.rls_entry2.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -109,6 +110,7 @@ def setUp(self):

# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
self.rls_entry3 = RowLevelSecurityFilter()
self.rls_entry3.name = "rls_entry3"
self.rls_entry3.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -122,6 +124,7 @@ def setUp(self):

# Create Base RowLevelSecurityFilter (birth_names boys)
self.rls_entry4 = RowLevelSecurityFilter()
self.rls_entry4.name = "rls_entry4"
self.rls_entry4.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -146,6 +149,94 @@ def tearDown(self):
session.delete(self.get_user("NoRlsRoleUser"))
session.commit()

@pytest.fixture()
def create_dataset(self):
with self.create_app().app_context():

dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
db.session.add(dataset)
db.session.flush()
db.session.commit()

yield dataset

# rollback changes (assuming cascade delete)
db.session.delete(dataset)
db.session.commit()

def _get_test_dataset(self):
return (
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
).one_or_none()

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_success(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
assert rls1 is not None

# Revert data changes
db.session.delete(rls1)
db.session.commit()

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_name_unique(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls_entry1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "Already exists." in data

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
self.login(username="admin")
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "This field is required." in data

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha")
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_sql_lab_insert_rls(

# now with RLS
rls = RowLevelSecurityFilter(
name="sqllab_rls1",
filter_type=RowLevelSecurityFilterType.REGULAR,
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
roles=[admin.roles[0]],
Expand Down

0 comments on commit 60eb109

Please sign in to comment.