Skip to content

Commit

Permalink
Implement user roles.
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankNagel committed Nov 23, 2015
1 parent fb8093c commit 8f5ba8d
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 12 deletions.
32 changes: 32 additions & 0 deletions alembic/versions/1a0ea61b9a91_add_user_roles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Add user roles
Revision ID: 1a0ea61b9a91
Revises: 44f16b57d26f
Create Date: 2015-11-23 12:54:42.724658
"""

# revision identifiers, used by Alembic.
revision = '1a0ea61b9a91'
down_revision = '44f16b57d26f'

from alembic import op
import sqlalchemy as sa
from sqlalchemy import DDL

from digital_ale.models import ArrayOfEnum


def upgrade():
#alembic doesn't create Enum if nested in Array
op.execute(DDL("CREATE TYPE user_role as ENUM('admin', 'editor')"))
### commands auto generated by Alembic - please adjust! ###
op.add_column('tbl_user', sa.Column('roles', ArrayOfEnum(sa.Enum('admin', 'editor', name='user_role')), nullable=True))
### end Alembic commands ###


def downgrade():
op.execute(DDL('Drop type user_role'))
### commands auto generated by Alembic - please adjust! ###
op.drop_column('tbl_user', 'roles')
### end Alembic commands ###
3 changes: 2 additions & 1 deletion digital_ale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DBSession,
Base,
RootFactory,
get_user_roles
)


Expand All @@ -24,7 +25,7 @@ def main(global_config, **settings):
reissue_time=settings.get('session.reissue_time', 360)
)

authn_policy = SessionAuthenticationPolicy()
authn_policy = SessionAuthenticationPolicy(callback=get_user_roles)
authz_policy = ACLAuthorizationPolicy()

config = Configurator(
Expand Down
41 changes: 40 additions & 1 deletion digital_ale/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
Enum,
)

from sqlalchemy.dialects.postgresql import ARRAY

from pyramid.security import (
Everyone,
Authenticated,
Expand All @@ -55,6 +57,29 @@ def hash_password(password):
DBSession = scoped_session(sessionmaker(extension=ZopeTransactionExtension()))
Base = declarative_base()

class ArrayOfEnum(ARRAY):
"""Helper class to support arrays of enums in PostgreSQL"""
def bind_expression(self, bindvalue):
return sa.cast(bindvalue, self)

def result_processor(self, dialect, coltype):
super_rp = super(ArrayOfEnum, self).result_processor(
dialect, coltype)

def handle_raw_string(value):
inner = re.match(r"^{(.*)}$", value).group(1)
return inner.split(",")

def process(value):
if value is None:
return None
return super_rp(handle_raw_string(value))
return process


RoleClass = namedtuple('Role', 'admin editor')
Role = RoleClass('admin', 'editor')


class User(Base):
"""
Expand All @@ -65,6 +90,7 @@ class User(Base):
login_name = Column(Unicode(50), unique=True)
display_name = Column(Unicode(50))
email = Column(Unicode(80))
roles = Column(ArrayOfEnum(Enum(*Role, name="user_role")))

_password = Column('password', Unicode(60))

Expand Down Expand Up @@ -95,6 +121,16 @@ def check_password(cls, username, password):
return crypt.check(user.password, password)


def get_user_roles(userid, request):
user = User.get_by_username(userid)
if user is None:
return None
elif user.roles is None:
return []
else:
return ['role:'+r for r in user.roles]


class Scan(Base):
__tablename__ = 'tbl_scan'
id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -405,7 +441,10 @@ def get_by_scan_prefix(cls, prefix):
class RootFactory(object):
__acl__ = [
(Allow, Everyone, 'view'),
(Allow, Authenticated, 'post')
(Allow, Authenticated, 'post'),
(Allow, 'role:editor', 'edit_sheet'),
(Allow, 'role:admin', 'bulk_extract')

]

def __init__(self, request):
Expand Down
28 changes: 18 additions & 10 deletions digital_ale/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
HTTPFound,
HTTPNotFound,
HTTPUnauthorized,
HTTPForbidden,
)

from pyramid.security import (
Expand Down Expand Up @@ -378,13 +379,13 @@ def place_candidate_delete(request):
return dict(status='OK')


@view_config(route_name='sheet_edit', renderer='json', request_method='POST')
@view_config(route_name='sheet_edit', renderer='json', request_method='POST', permission='edit_sheet')
def sheet_edit(request):
username = authenticated_userid(request)
user = User.get_by_username(username)
if user is None:
request.response.status_code = 401
return dict(status=401)
#should never happen with active authorization policy
raise HTTPForbidden()
concept_id = request.matchdict['concept_id']
scan_name = request.matchdict['scan_name']
message = ''
Expand All @@ -399,8 +400,8 @@ def sheet_edit(request):

new_status = request.POST.get('status', '')
if new_status not in SheetEntryState:
request.response.status_code = 401
return dict(status=401, reason="Invalid status code")
request.response.status_code = 400
return dict(status=400, reason="Invalid status code")
sheetEntry.status = new_status
sheetEntry.editor_fkey = user.id
sheetEntry.data = request.POST.get('data', '')
Expand All @@ -412,12 +413,19 @@ def sheet_edit(request):
return dict(status='OK')


@view_config(route_name='extract_pronounciation', renderer='json', request_method='POST')
@view_config(route_name='extract_pronounciation', renderer='json', request_method='POST', permission='bulk_extract')
def extract_pronounciation(request):
sheets = SheetEntry.extract_pronounciation(request.POST.get('all', '').lower() in ('true', 'yes', '1'))
return dict(status='OK', num_sheets=len(sheets))


@view_config(context=HTTPForbidden, route_name='extract_pronounciation', renderer='json', request_method='POST')
@view_config(context=HTTPForbidden, route_name='sheet_edit', renderer='json', request_method='POST')
def json_authorization_error(request):
username = authenticated_userid(request)
user = User.get_by_username(username)
if user is None:
if username is None:
request.response.status_code = 401
return dict(status=401)
sheets = SheetEntry.extract_pronounciation(request.POST.get('all', '').lower() in ('true', 'yes', '1'))
return dict(status='OK', num_sheets=len(sheets))
else:
request.response.status_code = 403
return dict(status=403)

0 comments on commit 8f5ba8d

Please sign in to comment.