Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass existing object references within access methods #14

Merged
merged 1 commit into from
Sep 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions awx/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def check_head_permissions(self, request, view, obj=None):

def check_get_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk'])
parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read',
parent_obj):
return False
Expand All @@ -44,12 +44,12 @@ def check_get_permissions(self, request, view, obj=None):

def check_post_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk'])
parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read',
parent_obj):
return False
if hasattr(view, 'parent_key'):
if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj.pk}):
if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj}):
return False
return True
elif getattr(view, 'is_job_start', False):
Expand Down
3 changes: 2 additions & 1 deletion awx/api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def get_rendered_html_form(self, data, view, method, request):
obj = getattr(view, 'object', None)
if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'):
try:
obj = view.get_object()
view.object = view.get_object()
obj = view.object
except Exception:
obj = None
with override_method(view, request, method) as request:
Expand Down
99 changes: 48 additions & 51 deletions awx/main/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError

# AWX
from awx.main.utils import * # noqa
from awx.main.utils import (
get_object_or_400,
get_pk_from_dict,
to_python_boolean,
get_licenser,
)
from awx.main.models import * # noqa
from awx.main.models.unified_jobs import ACTIVE_STATES
from awx.main.models.mixins import ResourceMixin
Expand All @@ -36,6 +41,36 @@
}


def get_object_from_data(field, Model, data, obj=None):
"""
Utility method to obtain related object in data according to fallbacks:
- if data contains key with pointer to Django object, return that
- if contains integer, get object from database
- if this does not work, raise exception
"""
try:
raw_value = data[field]
except KeyError:
# Calling method needs to deal with non-existence of key
raise ParseError(_("Required related field %s for permission check." % field))

if isinstance(raw_value, Model):
return raw_value
elif raw_value is None:
return None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
return getattr(obj, field)
else:
# Get the new resource from the database
return get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))


class StateConflict(ValidationError):
status_code = 409

Expand Down Expand Up @@ -205,24 +240,8 @@ def check_related(self, field, Model, data, role_field='admin_role',
# Use reference object's related fields, if given
new = getattr(data['reference_obj'], field)
elif data and field in data:
# Obtain the resource specified in `data`
raw_value = data[field]
if isinstance(raw_value, Model):
new = raw_value
elif raw_value is None:
new = None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
changed = False
else:
# Get the new resource from the database
new = get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))
elif data is None or field not in data:
new = get_object_from_data(field, Model, data, obj=obj)
else:
changed = False

# Obtain existing related resource
Expand Down Expand Up @@ -940,17 +959,14 @@ def can_read(self, obj):
def can_add(self, data):
if not data: # So the browseable API will work
return True
user_pk = get_pk_from_dict(data, 'user')
if user_pk:
user_obj = get_object_or_400(User, pk=user_pk)
if data and data.get('user', None):
user_obj = get_object_from_data('user', User, data)
return check_user_access(self.user, User, 'change', user_obj, None)
team_pk = get_pk_from_dict(data, 'team')
if team_pk:
team_obj = get_object_or_400(Team, pk=team_pk)
if data and data.get('team', None):
team_obj = get_object_from_data('team', Team, data)
return check_user_access(self.user, Team, 'change', team_obj, None)
organization_pk = get_pk_from_dict(data, 'organization')
if organization_pk:
organization_obj = get_object_or_400(Organization, pk=organization_pk)
if data and data.get('organization', None):
organization_obj = get_object_from_data('organization', Organization, data)
return check_user_access(self.user, Organization, 'change', organization_obj, None)
return False

Expand Down Expand Up @@ -1173,9 +1189,8 @@ def get_value(Class, field):
if reference_obj:
return getattr(reference_obj, field, None)
else:
pk = get_pk_from_dict(data, field)
if pk:
return get_object_or_400(Class, pk=pk)
if data and data.get(field, None):
return get_object_from_data(field, Class, data)
else:
return None

Expand Down Expand Up @@ -1261,23 +1276,6 @@ def changes_are_non_sensitive(self, obj, data):
return False
return True

def can_update_sensitive_fields(self, obj, data):
project_id = data.get('project', obj.project.id if obj.project else None)
inventory_id = data.get('inventory', obj.inventory.id if obj.inventory else None)
credential_id = data.get('credential', obj.credential.id if obj.credential else None)
vault_credential_id = data.get('credential', obj.vault_credential.id if obj.vault_credential else None)

if project_id and self.user not in Project.objects.get(pk=project_id).use_role:
return False
if inventory_id and self.user not in Inventory.objects.get(pk=inventory_id).use_role:
return False
if credential_id and self.user not in Credential.objects.get(pk=credential_id).use_role:
return False
if vault_credential_id and self.user not in Credential.objects.get(pk=vault_credential_id).use_role:
return False

return True

def can_delete(self, obj):
is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role
if not is_delete_allowed:
Expand Down Expand Up @@ -1387,9 +1385,8 @@ def can_add(self, data, validate_license=True):
add_data = dict(data.items())

# If a job template is provided, the user should have read access to it.
job_template_pk = get_pk_from_dict(data, 'job_template')
if job_template_pk:
job_template = get_object_or_400(JobTemplate, pk=job_template_pk)
if data and data.get('job_template', None):
job_template = get_object_from_data('job_template', JobTemplate, data)
add_data.setdefault('inventory', job_template.inventory.pk)
add_data.setdefault('project', job_template.project.pk)
add_data.setdefault('job_type', job_template.job_type)
Expand Down
24 changes: 12 additions & 12 deletions awx/main/tests/unit/api/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,28 +242,28 @@ def mock_request(self):
), method='GET')


def mock_view(self):
def mock_view(self, parent=None):
view = ResourceAccessList()
view.parent_model = Organization
view.kwargs = {'pk': 4}
if parent:
view.get_parent_object = lambda: parent
return view


def test_parent_access_check_failed(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization):
mock_access = mocker.MagicMock(__name__='for logger', return_value=False)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with pytest.raises(PermissionDenied):
self.mock_view().check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)
mock_access = mocker.MagicMock(__name__='for logger', return_value=False)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with pytest.raises(PermissionDenied):
self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)


def test_parent_access_check_worked(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization):
mock_access = mocker.MagicMock(__name__='for logger', return_value=True)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
self.mock_view().check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)
mock_access = mocker.MagicMock(__name__='for logger', return_value=True)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)


def test_related_search_reverse_FK_field():
Expand Down
5 changes: 4 additions & 1 deletion awx/main/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,10 @@ def get_pk_from_dict(_dict, key):
Helper for obtaining a pk from user data dict or None if not present.
'''
try:
return int(_dict[key])
val = _dict[key]
if isinstance(val, object) and hasattr(val, 'id'):
return val.id # return id if given model object
return int(val)
except (TypeError, KeyError, ValueError):
return None

Expand Down