Skip to content

Commit

Permalink
Merge pull request #14 from AlanCoding/rel_obj_conservation
Browse files Browse the repository at this point in the history
Pass existing object references within access methods
  • Loading branch information
matburt committed Sep 8, 2017
2 parents c6ae6b8 + 4194068 commit 2fb9b6c
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 68 deletions.
6 changes: 3 additions & 3 deletions awx/api/permissions.py
Expand Up @@ -34,7 +34,7 @@ def check_head_permissions(self, request, view, obj=None):

def check_get_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk'])
parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read',
parent_obj):
return False
Expand All @@ -44,12 +44,12 @@ def check_get_permissions(self, request, view, obj=None):

def check_post_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk'])
parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read',
parent_obj):
return False
if hasattr(view, 'parent_key'):
if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj.pk}):
if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj}):
return False
return True
elif getattr(view, 'is_job_start', False):
Expand Down
3 changes: 2 additions & 1 deletion awx/api/renderers.py
Expand Up @@ -48,7 +48,8 @@ def get_rendered_html_form(self, data, view, method, request):
obj = getattr(view, 'object', None)
if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'):
try:
obj = view.get_object()
view.object = view.get_object()
obj = view.object
except Exception:
obj = None
with override_method(view, request, method) as request:
Expand Down
99 changes: 48 additions & 51 deletions awx/main/access.py
Expand Up @@ -17,7 +17,12 @@
from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError

# AWX
from awx.main.utils import * # noqa
from awx.main.utils import (
get_object_or_400,
get_pk_from_dict,
to_python_boolean,
get_licenser,
)
from awx.main.models import * # noqa
from awx.main.models.unified_jobs import ACTIVE_STATES
from awx.main.models.mixins import ResourceMixin
Expand All @@ -36,6 +41,36 @@
}


def get_object_from_data(field, Model, data, obj=None):
"""
Utility method to obtain related object in data according to fallbacks:
- if data contains key with pointer to Django object, return that
- if contains integer, get object from database
- if this does not work, raise exception
"""
try:
raw_value = data[field]
except KeyError:
# Calling method needs to deal with non-existence of key
raise ParseError(_("Required related field %s for permission check." % field))

if isinstance(raw_value, Model):
return raw_value
elif raw_value is None:
return None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
return getattr(obj, field)
else:
# Get the new resource from the database
return get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))


class StateConflict(ValidationError):
status_code = 409

Expand Down Expand Up @@ -205,24 +240,8 @@ def check_related(self, field, Model, data, role_field='admin_role',
# Use reference object's related fields, if given
new = getattr(data['reference_obj'], field)
elif data and field in data:
# Obtain the resource specified in `data`
raw_value = data[field]
if isinstance(raw_value, Model):
new = raw_value
elif raw_value is None:
new = None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
changed = False
else:
# Get the new resource from the database
new = get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))
elif data is None or field not in data:
new = get_object_from_data(field, Model, data, obj=obj)
else:
changed = False

# Obtain existing related resource
Expand Down Expand Up @@ -940,17 +959,14 @@ def can_read(self, obj):
def can_add(self, data):
if not data: # So the browseable API will work
return True
user_pk = get_pk_from_dict(data, 'user')
if user_pk:
user_obj = get_object_or_400(User, pk=user_pk)
if data and data.get('user', None):
user_obj = get_object_from_data('user', User, data)
return check_user_access(self.user, User, 'change', user_obj, None)
team_pk = get_pk_from_dict(data, 'team')
if team_pk:
team_obj = get_object_or_400(Team, pk=team_pk)
if data and data.get('team', None):
team_obj = get_object_from_data('team', Team, data)
return check_user_access(self.user, Team, 'change', team_obj, None)
organization_pk = get_pk_from_dict(data, 'organization')
if organization_pk:
organization_obj = get_object_or_400(Organization, pk=organization_pk)
if data and data.get('organization', None):
organization_obj = get_object_from_data('organization', Organization, data)
return check_user_access(self.user, Organization, 'change', organization_obj, None)
return False

Expand Down Expand Up @@ -1173,9 +1189,8 @@ def get_value(Class, field):
if reference_obj:
return getattr(reference_obj, field, None)
else:
pk = get_pk_from_dict(data, field)
if pk:
return get_object_or_400(Class, pk=pk)
if data and data.get(field, None):
return get_object_from_data(field, Class, data)
else:
return None

Expand Down Expand Up @@ -1261,23 +1276,6 @@ def changes_are_non_sensitive(self, obj, data):
return False
return True

def can_update_sensitive_fields(self, obj, data):
project_id = data.get('project', obj.project.id if obj.project else None)
inventory_id = data.get('inventory', obj.inventory.id if obj.inventory else None)
credential_id = data.get('credential', obj.credential.id if obj.credential else None)
vault_credential_id = data.get('credential', obj.vault_credential.id if obj.vault_credential else None)

if project_id and self.user not in Project.objects.get(pk=project_id).use_role:
return False
if inventory_id and self.user not in Inventory.objects.get(pk=inventory_id).use_role:
return False
if credential_id and self.user not in Credential.objects.get(pk=credential_id).use_role:
return False
if vault_credential_id and self.user not in Credential.objects.get(pk=vault_credential_id).use_role:
return False

return True

def can_delete(self, obj):
is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role
if not is_delete_allowed:
Expand Down Expand Up @@ -1387,9 +1385,8 @@ def can_add(self, data, validate_license=True):
add_data = dict(data.items())

# If a job template is provided, the user should have read access to it.
job_template_pk = get_pk_from_dict(data, 'job_template')
if job_template_pk:
job_template = get_object_or_400(JobTemplate, pk=job_template_pk)
if data and data.get('job_template', None):
job_template = get_object_from_data('job_template', JobTemplate, data)
add_data.setdefault('inventory', job_template.inventory.pk)
add_data.setdefault('project', job_template.project.pk)
add_data.setdefault('job_type', job_template.job_type)
Expand Down
24 changes: 12 additions & 12 deletions awx/main/tests/unit/api/test_generics.py
Expand Up @@ -242,28 +242,28 @@ def mock_request(self):
), method='GET')


def mock_view(self):
def mock_view(self, parent=None):
view = ResourceAccessList()
view.parent_model = Organization
view.kwargs = {'pk': 4}
if parent:
view.get_parent_object = lambda: parent
return view


def test_parent_access_check_failed(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization):
mock_access = mocker.MagicMock(__name__='for logger', return_value=False)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with pytest.raises(PermissionDenied):
self.mock_view().check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)
mock_access = mocker.MagicMock(__name__='for logger', return_value=False)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with pytest.raises(PermissionDenied):
self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)


def test_parent_access_check_worked(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization):
mock_access = mocker.MagicMock(__name__='for logger', return_value=True)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
self.mock_view().check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)
mock_access = mocker.MagicMock(__name__='for logger', return_value=True)
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
mock_access.assert_called_once_with(mock_organization)


def test_related_search_reverse_FK_field():
Expand Down
5 changes: 4 additions & 1 deletion awx/main/utils/common.py
Expand Up @@ -724,7 +724,10 @@ def get_pk_from_dict(_dict, key):
Helper for obtaining a pk from user data dict or None if not present.
'''
try:
return int(_dict[key])
val = _dict[key]
if isinstance(val, object) and hasattr(val, 'id'):
return val.id # return id if given model object
return int(val)
except (TypeError, KeyError, ValueError):
return None

Expand Down

0 comments on commit 2fb9b6c

Please sign in to comment.