Skip to content

Commit

Permalink
refactor: Using self.get_session in security manager (#10146)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and john-bodley committed Jul 4, 2020
1 parent b181e48 commit 33584a8
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,15 @@ def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
if not conf.get("PUBLIC_ROLE_LIKE_GAMMA", False):
return None

from superset import db

return db.session.query(self.role_model).filter_by(name="Public").first()
return (
self.get_session.query(self.role_model)
.filter_by(name="Public")
.one_or_none()
)

def user_view_menu_names(self, permission_name: str) -> Set[str]:
from superset import db

base_query = (
db.session.query(self.viewmenu_model.name)
self.get_session.query(self.viewmenu_model.name)
.join(self.permissionview_model)
.join(self.permission_model)
.join(assoc_permissionview_role)
Expand Down Expand Up @@ -450,7 +450,6 @@ def get_schemas_accessible_by_user(
:returns: The list of accessible SQL schemas
"""

from superset import db
from superset.connectors.sqla.models import SqlaTable

if hierarchical and self.can_access_database(database):
Expand All @@ -467,7 +466,7 @@ def get_schemas_accessible_by_user(
perms = self.user_view_menu_names("datasource_access")
if perms:
tables = (
db.session.query(SqlaTable.schema)
self.get_session.query(SqlaTable.schema)
.filter(SqlaTable.database_id == database.id)
.filter(SqlaTable.schema.isnot(None))
.filter(SqlaTable.schema != "")
Expand All @@ -493,8 +492,6 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name
:returns: The list of accessible SQL tables w/ schema
"""

from superset import db

if self.can_access_database(database):
return datasource_names

Expand All @@ -506,7 +503,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
db.session, database, user_perms, schema_perms
self.get_session, database, user_perms, schema_perms
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
Expand Down Expand Up @@ -552,7 +549,6 @@ def create_missing_perms(self) -> None:
Creates missing FAB permissions for datasources, schemas and metrics.
"""

from superset import db
from superset.connectors.base.models import BaseMetric
from superset.models import core as models

Expand All @@ -568,20 +564,20 @@ def merge_pv(view_menu: str, perm: str) -> None:
self.add_permission_view_menu(view_menu, perm)

logger.info("Creating missing datasource permissions.")
datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = ConnectorRegistry.get_all_datasources(self.get_session)
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())

logger.info("Creating missing database permissions.")
databases = db.session.query(models.Database).all()
databases = self.get_session.query(models.Database).all()
for database in databases:
merge_pv("database_access", database.perm)

logger.info("Creating missing metrics permissions")
metrics: List[BaseMetric] = []
for datasource_class in ConnectorRegistry.sources.values():
metrics += list(db.session.query(datasource_class.metric_class).all())
metrics += list(self.get_session.query(datasource_class.metric_class).all())

def clean_perms(self) -> None:
"""
Expand Down Expand Up @@ -786,7 +782,7 @@ def set_perm( # pylint: disable=no-self-use,unused-argument
"""
Set the datasource permissions.
:param mapper: The table mappper
:param mapper: The table mapper
:param connection: The DB-API connection
:param target: The mapped instance being persisted
"""
Expand Down Expand Up @@ -943,30 +939,29 @@ def get_rls_filters( # pylint: disable=no-self-use
:returns: A list of filters
"""
if hasattr(g, "user") and hasattr(g.user, "id"):
from superset import db
from superset.connectors.sqla.models import (
RLSFilterRoles,
RLSFilterTables,
RowLevelSecurityFilter,
)

user_roles = (
db.session.query(assoc_user_role.c.role_id)
self.get_session.query(assoc_user_role.c.role_id)
.filter(assoc_user_role.c.user_id == g.user.id)
.subquery()
)
filter_roles = (
db.session.query(RLSFilterRoles.c.rls_filter_id)
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
filter_tables = (
db.session.query(RLSFilterTables.c.rls_filter_id)
self.get_session.query(RLSFilterTables.c.rls_filter_id)
.filter(RLSFilterTables.c.table_id == table.id)
.subquery()
)
query = (
db.session.query(
self.get_session.query(
RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause
)
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
Expand Down

0 comments on commit 33584a8

Please sign in to comment.