Skip to content

Commit

Permalink
Migrate permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
Bogdan Kyryliuk committed Mar 21, 2017
1 parent 42ff28b commit ae57989
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 14 deletions.
19 changes: 16 additions & 3 deletions superset/connectors/connector_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,29 @@ def get_all_datasources(cls, session):
return datasources

@classmethod
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
schema, database_name):
def get_datasources_by_name(
cls, session, datasource_type, datasource_name, schema,
database_name
):
datasource_class = ConnectorRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all()

# Filter datasoures that don't have database.
db_ds = [d for d in datasources if d.database and
d.database.name == database_name and
d.name == datasource_name and schema == schema]
return db_ds[0]
return db_ds


@classmethod
def get_datasource_by_name(
cls, session, datasource_type, datasource_name, schema,
database_name
):
return cls.get_datasources_by_name(
session, datasource_type, datasource_name, schema,
database_name
)[0]

@classmethod
def query_datasources_by_permissions(cls, session, database, permissions):
Expand Down
12 changes: 9 additions & 3 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class DruidCluster(Model, AuditMixinNullable):
broker_endpoint = Column(String(255), default='druid/v2')
metadata_last_refreshed = Column(DateTime)
cache_timeout = Column(Integer)
perm = Column(String(1000))

def __repr__(self):
return self.verbose_name if self.verbose_name else self.cluster_name
Expand Down Expand Up @@ -103,9 +104,8 @@ def refresh_datasources(self, datasource_name=None, merge_flag=False):
if not datasource_name or datasource_name == datasource:
DruidDatasource.sync_to_db(datasource, self, merge_flag)

@property
def perm(self):
return self.cluster_name
def get_perm(self):
return '{}.{}'.format(self.type, self.unique_name)

@property
def name(self):
Expand Down Expand Up @@ -253,6 +253,11 @@ def lookup_obj(lookup_column):
return import_util.import_simple_obj(db.session, i_column, lookup_obj)


sa.event.listen(DruidCluster, 'after_insert', set_perm)
sa.event.listen(DruidCluster, 'after_update', set_perm)



class DruidMetric(Model, BaseMetric):

"""ORM object referencing Druid metrics for a datasource"""
Expand Down Expand Up @@ -379,6 +384,7 @@ def schema_perm(self):

def get_perm(self):
cluster = self.cluster
logging.info('looking for the {}'.format(self.cluster_name))
if not cluster:
cluster = db.session.query(DruidCluster).filter_by(
cluster_name=self.cluster_name).one()
Expand Down
6 changes: 5 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,11 @@ def schema_perm(self):
return utils.get_schema_perm(self.database, self.schema)

def get_perm(self):
return "{}.{}".format(self.database.perm, self.name)
database = self.database
if not database:
database = db.session.query(Database).filter_by(
id=self.database_id).one()
return "{}.{}".format(database.perm, self.name)

@property
def name(self):
Expand Down
220 changes: 220 additions & 0 deletions superset/migrations/versions/e8c16094b97b_rename_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Renames permissions
Revision ID: e8c16094b97b
Revises: db527d8c4c78
Create Date: 2017-03-20 10:07:20.926604
"""

# revision identifiers, used by Alembic.
revision = 'e8c16094b97b'
down_revision = 'db527d8c4c78'

import logging
import sqlalchemy as sa

from alembic import op
from superset import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, sessionmaker
from sqlalchemy import (
Column, Integer, String, ForeignKey, Sequence)

Base = declarative_base()


class Database(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'dbs'
type = "table"
id = Column(Integer, primary_key=True)
database_name = Column(String(250), unique=True)
perm = Column(String(1000))

def get_perm(self):
return '{}.{}'.format(self.type, self.database_name)

def get_old_perm(self):
return "[{obj.database_name}].(id:{obj.id})".format(obj=self)


class DruidCluster(Base):
__tablename__ = 'clusters'
type = "druid"
id = Column(Integer, primary_key=True)
cluster_name = Column(String(250), unique=True)
perm = Column(String(1000))

def get_perm(self):
return '{}.{}'.format(self.type, self.cluster_name)

def get_old_perm(self):
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)


class SqlaTable(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'tables'
type = "table"
id = Column(Integer, primary_key=True)
table_name = Column(String(250))
schema = Column(String(255))
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
database = relationship(
'Database',
backref=backref('tables', cascade='all, delete-orphan'),
foreign_keys=[database_id])
perm = Column(String(1000))

def get_perm(self):
return "{}.{}".format(self.database.perm, self.name)

def get_old_perm(self):
return "[{obj.database}].[{obj.table_name}](id:{obj.id})".format(
obj=self)

def get_schema_perm(self):
"""Returns schema permission if present, database one otherwise."""
if self.schema:
return "{}.{}".format(self.database.perm, self.schema)

def get_old_schema_perm(self):
"""Returns schema permission if present, database one otherwise."""
if self.schema:
return "[{}].[{}]".format(self.database, self.schema)


class DruidDatasource(Base):
"""Declarative class to do query in upgrade"""
type = "druid"
__tablename__ = 'datasources'
id = Column(Integer, primary_key=True)
datasource_name = Column(String(255), unique=True)
cluster_name = Column(
String(250), ForeignKey('clusters.cluster_name'))
cluster = relationship(
'DruidCluster', backref='datasources', foreign_keys=[cluster_name])
perm = Column(String(1000))

def get_perm(self):
return "{}.{}".format(self.cluster.perm, self.name)

def get_old_perm(self):
return (
"[{obj.cluster_name}].[{obj.datasource_name}](id:{obj.id})".format(
obj=self))


class ViewMenu(Base):
__tablename__ = 'ab_view_menu'
id = Column(Integer, Sequence('ab_view_menu_id_seq'), primary_key=True)
name = Column(String(100), unique=True, nullable=False)


class PermissionView(Base):
__tablename__ = 'ab_permission_view'
id = Column(Integer, Sequence('ab_permission_view_id_seq'),
primary_key=True)
view_menu_id = Column(Integer, ForeignKey('ab_view_menu.id'))
view_menu = relationship("ViewMenu")


def update_perms(
obj_class,
old_perm_attr='perm',
new_perm_attr='get_perm',
db_attr='perm',
):
session = db.session
for obj in session.query(obj_class).all():
old_perm = getattr(obj, old_perm_attr)
logging.info('renaming {} permission'.format(old_perm))
old_view_menu = session.query(ViewMenu).filter_by(name=old_perm).first()
new_perm = getattr(obj, new_perm_attr)
if old_view_menu:
new_view_menu = session.query(ViewMenu).filter_by(
name=new_perm).first()
if new_view_menu:
# View menu already exists, attach permission view menues to the
# found view menu. Impossible to reverse.
pvms = session.query(PermissionView).filter_by(
view_menu_id=old_view_menu.id).all()
for pvm in pvms:
pvm.view_menu_id = new_view_menu.id
session.flush()
else:
# Rename the view menu name
old_view_menu.name = new_perm
session.flush()
# Persist update perm value.
if db_attr:
setattr(obj, db_attr, new_perm)
session.flush()
session.commit()


def downgrade_perms(
obj_class,
old_perm_attr='get_old_perm',
new_perm_attr='get_perm',
db_attr='perm',
):
session = db.session
for obj in session.query(obj_class).all():
new_perm = getattr(obj, new_perm_attr)
new_view_menu = session.query(ViewMenu).filter_by(name=new_perm).one()

old_perm = getattr(obj, old_perm_attr)
old_view_menu = session.query(ViewMenu).filter_by(
name=old_perm).first()

# Should not exist.
if old_view_menu:
# View menu already exists, attach permission view menues to the
# found view menu. Impossible to reverse.
pvms = session.query(PermissionView).filter_by(
view_menu_id=new_perm.id).all()
for pvm in pvms:
pvm.view_menu_id = old_view_menu.id
session.flush()
else:
# Rename the view menu name
new_view_menu.name = old_perm
session.flush()

# Persist update perm value if not processing schema.
if db_attr:
setattr(obj, db_attr, new_perm)
session.flush()
session.commit()


def upgrade():
op.add_column('clusters', sa.Column('perm', sa.Text(), nullable=True))
bind = op.get_bind()
Session = sessionmaker()
session = Session(bind=bind)
session.commit()

update_perms(Database)
update_perms(SqlaTable)
update_perms(
SqlaTable, old_perm_attr='get_old_schema_perm',
new_perm_attr='get_schema_perm', db_attr=None,
)

update_perms(DruidCluster)
update_perms(DruidDatasource)


def downgrade():
downgrade_perms(Database)
downgrade_perms(SqlaTable)
downgrade_perms(
SqlaTable, old_perm_attr='get_old_schema_perm',
new_perm_attr='get_schema_perm', db_attr=None,
)
downgrade_perms(DruidCluster)
downgrade_perms(DruidDatasource)
op.drop_column('clusters', 'perm')

6 changes: 3 additions & 3 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
params = Column(Text)
description = Column(Text)
cache_timeout = Column(Integer)
# TODO: remove perm as slice datasource may be switched.
perm = Column(String(1000))
owners = relationship("User", secondary=slice_user)

Expand Down Expand Up @@ -128,8 +129,7 @@ def get_datasource(self):

@renders('datasource_name')
def datasource_link(self):
datasource = self.datasource
if datasource:
if self.datasource:
return self.datasource.link

@property
Expand Down Expand Up @@ -714,7 +714,7 @@ def sql_url(self):
return '/superset/sql/{}/'.format(self.id)

def get_perm(self):
return self.unique_name
return '{}.{}'.format(self.type, self.unique_name)

sqla.event.listen(Database, 'after_insert', set_perm)
sqla.event.listen(Database, 'after_update', set_perm)
Expand Down
5 changes: 3 additions & 2 deletions superset/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def datasource_access_by_name(
sm, 'schema_access', schema_perm, g.user):
return True

datasources = ConnectorRegistry.query_datasources_by_name(
db.session, database, datasource_name, schema=schema)
# Checking among duplicated datasources.
datasources = ConnectorRegistry.get_datasources_by_name(
db.session, database.type, datasource_name, schema, database.name)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
Expand Down
9 changes: 9 additions & 0 deletions tests/druid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ def test_client(self, PyDruid):
metadata_last_refreshed=datetime.now())

db.session.add(cluster)
db.session.commit()

druid_datasource = DruidDatasource(
datasource_name='druid_test',
cluster_name='test_cluster'
)
db.session.add(druid_datasource)
db.session.commit()

cluster.get_datasources = Mock(return_value=['druid_test'])
cluster.get_druid_version = Mock(return_value='0.9.1')
cluster.refresh_datasources()
Expand Down
4 changes: 3 additions & 1 deletion tests/import_export_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ def create_dashboard(self, title, id=0, slcs=[]):

def create_table(
self, name, schema='', id=0, cols_names=[], metric_names=[]):
params = {'remote_id': id, 'database_name': 'main'}
main = self.get_main_database(db.session)
params = {'remote_id': id, 'database_name': main.database_name}
table = SqlaTable(
id=id,
database_id=main.id,
schema=schema,
table_name=name,
params=json.dumps(params)
Expand Down
2 changes: 1 addition & 1 deletion tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_sql_json_has_access(self):
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
.join(ab_models.ViewMenu)
.filter(ab_models.ViewMenu.name == '[main].(id:1)')
.filter(ab_models.ViewMenu.name == 'table.main')
.first()
)
astronaut = sm.add_role("Astronaut")
Expand Down

0 comments on commit ae57989

Please sign in to comment.