From 41940687f1648ff1865d5ea9b677ddacbb42dbfe Mon Sep 17 00:00:00 2001 From: AlanCoding Date: Wed, 30 Aug 2017 16:05:02 -0400 Subject: [PATCH] Pass existing object references within access methods This avoids re-loading objects from the database in our chain of permission checking, wherever possible. access.py is equiped to handle object references instead of pk ints, and permissions.py is changed to pass those refs. --- awx/api/permissions.py | 6 +- awx/api/renderers.py | 3 +- awx/main/access.py | 99 ++++++++++++------------ awx/main/tests/unit/api/test_generics.py | 24 +++--- awx/main/utils/common.py | 5 +- 5 files changed, 69 insertions(+), 68 deletions(-) diff --git a/awx/api/permissions.py b/awx/api/permissions.py index cf29f3454ad3..e01ab4d9fb4f 100644 --- a/awx/api/permissions.py +++ b/awx/api/permissions.py @@ -34,7 +34,7 @@ def check_head_permissions(self, request, view, obj=None): def check_get_permissions(self, request, view, obj=None): if hasattr(view, 'parent_model'): - parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) + parent_obj = view.get_parent_object() if not check_user_access(request.user, view.parent_model, 'read', parent_obj): return False @@ -44,12 +44,12 @@ def check_get_permissions(self, request, view, obj=None): def check_post_permissions(self, request, view, obj=None): if hasattr(view, 'parent_model'): - parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) + parent_obj = view.get_parent_object() if not check_user_access(request.user, view.parent_model, 'read', parent_obj): return False if hasattr(view, 'parent_key'): - if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj.pk}): + if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj}): return False return True elif getattr(view, 'is_job_start', False): diff --git a/awx/api/renderers.py b/awx/api/renderers.py index 006057a09bee..f6fb089d4744 100644 --- a/awx/api/renderers.py +++ b/awx/api/renderers.py @@ -48,7 +48,8 @@ def get_rendered_html_form(self, data, view, method, request): obj = getattr(view, 'object', None) if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'): try: - obj = view.get_object() + view.object = view.get_object() + obj = view.object except Exception: obj = None with override_method(view, request, method) as request: diff --git a/awx/main/access.py b/awx/main/access.py index d73b87738993..bf72dd3d3433 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -17,7 +17,12 @@ from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError # AWX -from awx.main.utils import * # noqa +from awx.main.utils import ( + get_object_or_400, + get_pk_from_dict, + to_python_boolean, + get_licenser, +) from awx.main.models import * # noqa from awx.main.models.unified_jobs import ACTIVE_STATES from awx.main.models.mixins import ResourceMixin @@ -36,6 +41,36 @@ } +def get_object_from_data(field, Model, data, obj=None): + """ + Utility method to obtain related object in data according to fallbacks: + - if data contains key with pointer to Django object, return that + - if contains integer, get object from database + - if this does not work, raise exception + """ + try: + raw_value = data[field] + except KeyError: + # Calling method needs to deal with non-existence of key + raise ParseError(_("Required related field %s for permission check." % field)) + + if isinstance(raw_value, Model): + return raw_value + elif raw_value is None: + return None + else: + try: + new_pk = int(raw_value) + # Avoid database query by comparing pk to model for similarity + if obj and new_pk == getattr(obj, '%s_id' % field, None): + return getattr(obj, field) + else: + # Get the new resource from the database + return get_object_or_400(Model, pk=new_pk) + except (TypeError, ValueError): + raise ParseError(_("Bad data found in related field %s." % field)) + + class StateConflict(ValidationError): status_code = 409 @@ -205,24 +240,8 @@ def check_related(self, field, Model, data, role_field='admin_role', # Use reference object's related fields, if given new = getattr(data['reference_obj'], field) elif data and field in data: - # Obtain the resource specified in `data` - raw_value = data[field] - if isinstance(raw_value, Model): - new = raw_value - elif raw_value is None: - new = None - else: - try: - new_pk = int(raw_value) - # Avoid database query by comparing pk to model for similarity - if obj and new_pk == getattr(obj, '%s_id' % field, None): - changed = False - else: - # Get the new resource from the database - new = get_object_or_400(Model, pk=new_pk) - except (TypeError, ValueError): - raise ParseError(_("Bad data found in related field %s." % field)) - elif data is None or field not in data: + new = get_object_from_data(field, Model, data, obj=obj) + else: changed = False # Obtain existing related resource @@ -940,17 +959,14 @@ def can_read(self, obj): def can_add(self, data): if not data: # So the browseable API will work return True - user_pk = get_pk_from_dict(data, 'user') - if user_pk: - user_obj = get_object_or_400(User, pk=user_pk) + if data and data.get('user', None): + user_obj = get_object_from_data('user', User, data) return check_user_access(self.user, User, 'change', user_obj, None) - team_pk = get_pk_from_dict(data, 'team') - if team_pk: - team_obj = get_object_or_400(Team, pk=team_pk) + if data and data.get('team', None): + team_obj = get_object_from_data('team', Team, data) return check_user_access(self.user, Team, 'change', team_obj, None) - organization_pk = get_pk_from_dict(data, 'organization') - if organization_pk: - organization_obj = get_object_or_400(Organization, pk=organization_pk) + if data and data.get('organization', None): + organization_obj = get_object_from_data('organization', Organization, data) return check_user_access(self.user, Organization, 'change', organization_obj, None) return False @@ -1173,9 +1189,8 @@ def get_value(Class, field): if reference_obj: return getattr(reference_obj, field, None) else: - pk = get_pk_from_dict(data, field) - if pk: - return get_object_or_400(Class, pk=pk) + if data and data.get(field, None): + return get_object_from_data(field, Class, data) else: return None @@ -1261,23 +1276,6 @@ def changes_are_non_sensitive(self, obj, data): return False return True - def can_update_sensitive_fields(self, obj, data): - project_id = data.get('project', obj.project.id if obj.project else None) - inventory_id = data.get('inventory', obj.inventory.id if obj.inventory else None) - credential_id = data.get('credential', obj.credential.id if obj.credential else None) - vault_credential_id = data.get('credential', obj.vault_credential.id if obj.vault_credential else None) - - if project_id and self.user not in Project.objects.get(pk=project_id).use_role: - return False - if inventory_id and self.user not in Inventory.objects.get(pk=inventory_id).use_role: - return False - if credential_id and self.user not in Credential.objects.get(pk=credential_id).use_role: - return False - if vault_credential_id and self.user not in Credential.objects.get(pk=vault_credential_id).use_role: - return False - - return True - def can_delete(self, obj): is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role if not is_delete_allowed: @@ -1387,9 +1385,8 @@ def can_add(self, data, validate_license=True): add_data = dict(data.items()) # If a job template is provided, the user should have read access to it. - job_template_pk = get_pk_from_dict(data, 'job_template') - if job_template_pk: - job_template = get_object_or_400(JobTemplate, pk=job_template_pk) + if data and data.get('job_template', None): + job_template = get_object_from_data('job_template', JobTemplate, data) add_data.setdefault('inventory', job_template.inventory.pk) add_data.setdefault('project', job_template.project.pk) add_data.setdefault('job_type', job_template.job_type) diff --git a/awx/main/tests/unit/api/test_generics.py b/awx/main/tests/unit/api/test_generics.py index 10baf7eab1fd..62eac9d99c7f 100644 --- a/awx/main/tests/unit/api/test_generics.py +++ b/awx/main/tests/unit/api/test_generics.py @@ -242,28 +242,28 @@ def mock_request(self): ), method='GET') - def mock_view(self): + def mock_view(self, parent=None): view = ResourceAccessList() view.parent_model = Organization view.kwargs = {'pk': 4} + if parent: + view.get_parent_object = lambda: parent return view def test_parent_access_check_failed(self, mocker, mock_organization): - with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): - mock_access = mocker.MagicMock(__name__='for logger', return_value=False) - with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): - with pytest.raises(PermissionDenied): - self.mock_view().check_permissions(self.mock_request()) - mock_access.assert_called_once_with(mock_organization) + mock_access = mocker.MagicMock(__name__='for logger', return_value=False) + with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): + with pytest.raises(PermissionDenied): + self.mock_view(parent=mock_organization).check_permissions(self.mock_request()) + mock_access.assert_called_once_with(mock_organization) def test_parent_access_check_worked(self, mocker, mock_organization): - with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): - mock_access = mocker.MagicMock(__name__='for logger', return_value=True) - with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): - self.mock_view().check_permissions(self.mock_request()) - mock_access.assert_called_once_with(mock_organization) + mock_access = mocker.MagicMock(__name__='for logger', return_value=True) + with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): + self.mock_view(parent=mock_organization).check_permissions(self.mock_request()) + mock_access.assert_called_once_with(mock_organization) def test_related_search_reverse_FK_field(): diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index 58d795567f6c..291ff0722efe 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -724,7 +724,10 @@ def get_pk_from_dict(_dict, key): Helper for obtaining a pk from user data dict or None if not present. ''' try: - return int(_dict[key]) + val = _dict[key] + if isinstance(val, object) and hasattr(val, 'id'): + return val.id # return id if given model object + return int(val) except (TypeError, KeyError, ValueError): return None