Skip to content

Commit

Permalink
Merge 8e57a4d into 5eeb2bd
Browse files Browse the repository at this point in the history
  • Loading branch information
romcheg committed Dec 16, 2019
2 parents 5eeb2bd + 8e57a4d commit cae66f4
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 1 deletion.
41 changes: 40 additions & 1 deletion src/ralph/lib/custom_fields/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CustomFieldAdmin(RalphAdmin):
list_filter = ['type']
fields = [
'name', 'attribute_name', 'type', 'choices', 'default_value',
'use_as_configuration_variable'
'managing_group', 'use_as_configuration_variable'
]
readonly_fields = ['attribute_name']

Expand Down Expand Up @@ -135,3 +135,42 @@ def changeform_view(
return super().changeform_view(
request, object_id, form_url, extra_context
)

def _create_formsets(self, request, obj, change):
"""
Helper function to generate formsets for add/change_view
99% of this function contains unaltered code from Django. The only
alternation is that it passes request objects to custom field form sets
in order to make it impossible to edit restricted custom fields.
"""
formsets = []
inline_instances = []
prefixes = {}
get_formsets_args = [request]
if change:
get_formsets_args.append(obj)
for FormSet, inline in self.get_formsets_with_inlines(*get_formsets_args): # noqa: E501
prefix = FormSet.get_default_prefix()
prefixes[prefix] = prefixes.get(prefix, 0) + 1
if prefixes[prefix] != 1 or not prefix:
prefix = "%s-%s" % (prefix, prefixes[prefix])
formset_params = {
'instance': obj,
'prefix': prefix,
'queryset': inline.get_queryset(request),
}

if issubclass(FormSet, CustomFieldValueFormSet):
formset_params['request'] = request

if request.method == 'POST':
formset_params.update({
'data': request.POST,
'files': request.FILES,
'save_as_new': '_saveasnew' in request.POST
})
formsets.append(FormSet(**formset_params))
inline_instances.append(inline)
return formsets, inline_instances
45 changes: 45 additions & 0 deletions src/ralph/lib/custom_fields/api/viewsets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from django.contrib.contenttypes.models import ContentType
from rest_framework import viewsets
from rest_framework.response import Response
from rest_framework.status import HTTP_403_FORBIDDEN

from ..models import CustomFieldValue
from .serializers import (
Expand Down Expand Up @@ -42,6 +44,12 @@ def _get_related_model_info(self):
}
return info

def _user_can_manage_customfield(self, user, custom_field):
return (
custom_field.managing_group is None or
user.groups.filter(pk=custom_field.managing_group.pk).exists()
)

def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
return queryset.filter(**self._get_related_model_info())
Expand All @@ -51,3 +59,40 @@ def get_serializer(self, *args, **kwargs):
if kwargs.get('data') is not None:
kwargs['data'].update(self._get_related_model_info())
return super().get_serializer(*args, **kwargs)

def create(self, request, *args, **kwargs):
"""
Enforce user to be in a required group for restricted custom fields.
"""
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

custom_field = serializer.validated_data['custom_field']
if self._user_can_manage_customfield(request.user, custom_field):
return super().create(request, *args, **kwargs)
else:
return Response(status=HTTP_403_FORBIDDEN)

def update(self, request, *args, **kwargs):
"""
Enforce user to be in a required group for restricted custom fields.
"""
custom_field = self.get_object().custom_field
if self._user_can_manage_customfield(request.user, custom_field):
return super().update(request, *args, **kwargs)
else:
return Response(status=HTTP_403_FORBIDDEN)

def destroy(self, request, *args, **kwargs):
"""
Enforce user to be in a required group for restricted custom fields.
"""
custom_field = self.get_object().custom_field

if self._user_can_manage_customfield(request.user, custom_field):
return super().destroy(request, *args, **kwargs)
else:
return Response(status=HTTP_403_FORBIDDEN)
10 changes: 10 additions & 0 deletions src/ralph/lib/custom_fields/forms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django import forms
from django.contrib.contenttypes.forms import BaseGenericInlineFormSet
from django.contrib.contenttypes.models import ContentType
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext

Expand Down Expand Up @@ -39,6 +40,15 @@ def save(self, *args, **kwargs):


class CustomFieldValueFormSet(BaseGenericInlineFormSet):

def __init__(self, request=None, queryset=None, *args, **kwargs):
queryset = queryset.filter(
Q(custom_field__managing_group__isnull=True) |
Q(custom_field__managing_group__in=request.user.groups.all())
)

super().__init__(queryset=queryset, *args, **kwargs)

def _construct_form(self, i, **kwargs):
form = super()._construct_form(i, **kwargs)
# fix for https://code.djangoproject.com/ticket/12028, together with
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('auth', '0006_require_contenttypes_0002'),
('custom_fields', '0004_auto_20161214_1126'),
]

operations = [
migrations.AddField(
model_name='customfield',
name='managing_group',
field=models.ForeignKey(blank=True, null=True, help_text='When set, only members of the specified group will be allowed to set, change or unset values of this custom field for objects.', to='auth.Group'),
),
]
9 changes: 9 additions & 0 deletions src/ralph/lib/custom_fields/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dj.choices import Choices
from django import forms
from django.contrib.auth.models import Group
from django.contrib.contenttypes import generic

from django.contrib.contenttypes.models import ContentType
Expand Down Expand Up @@ -72,6 +73,14 @@ class CustomField(AdminAbsoluteUrlMixin, TimeStampMixin, models.Model):
blank=True,
default='',
)
managing_group = models.ForeignKey(
Group, blank=True, null=True,
help_text=_(
"When set, only members of the specified group will be "
"allowed to set, change or unset values of this custom field "
"for objects."
)
)
use_as_configuration_variable = models.BooleanField(
default=False,
help_text=_(
Expand Down
53 changes: 53 additions & 0 deletions src/ralph/lib/custom_fields/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.core.urlresolvers import reverse
from django.test import RequestFactory, TestCase

from ralph.accounts.tests.factories import GroupFactory
from ..models import CustomField, CustomFieldTypes, CustomFieldValue
from .models import ModelA, ModelB, SomeModel

Expand Down Expand Up @@ -242,3 +243,55 @@ def test_clearing_values_of_children_objects(self):
response = self.client.post(self.a1.get_absolute_url(), data)
self.assertEqual(response.status_code, 302)
self.assertNotIn(self.cfv1, list(self.sm1.custom_fields.all()))

def test_custom_field_not_in_form_for_nonmatching_managing_group(self):
self.custom_field_str.managing_group = GroupFactory()
self.custom_field_str.save()

response = self.client.get(self.sm1.get_absolute_url())

self.assertEqual(1, len(response.context_data['custom_fields_all']))
self.assertEqual(
'sample_value',
response.context_data['custom_fields_all'][0]['value']
)

filled_in_custom_field_forms = [
form
for form in response.context_data['inline_admin_formsets'][0].formset.forms
if form.fields['id'].initial is not None
]

self.assertEqual(0, len(filled_in_custom_field_forms))


def test_custom_field_in_form_for_matching_managing_group(self):
group = GroupFactory()

self.user.groups.add(group)
self.custom_field_str.managing_group = group

self.user.save()
self.custom_field_str.save()

response = self.client.get(self.sm1.get_absolute_url())

self.assertEqual(1, len(response.context_data['custom_fields_all']))
self.assertEqual(
'sample_value',
response.context_data['custom_fields_all'][0]['value']
)

filled_in_custom_field_forms = [
form
for form in response.context_data['inline_admin_formsets'][0].formset.forms
if form.fields['id'].initial is not None
]

self.assertEqual(1, len(filled_in_custom_field_forms))
form = filled_in_custom_field_forms[0]

self.assertEqual(
self.cfv1.id,
form.fields['id'].initial
)
110 changes: 110 additions & 0 deletions src/ralph/lib/custom_fields/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# -*- coding: utf-8 -*-
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.core.urlresolvers import reverse
from rest_framework import status
from rest_framework.test import APITestCase

from ralph.accounts.models import RalphUser
from ralph.accounts.tests.factories import GroupFactory
from ralph.tests.factories import UserFactory
from ..models import CustomField, CustomFieldTypes, CustomFieldValue
from ..signals import api_post_create, api_post_update
from .models import ModelA, ModelB, SomeModel
Expand Down Expand Up @@ -205,6 +209,112 @@ def test_add_new_customfield_value_should_pass(self):
self.assertEqual(cfv.custom_field, self.custom_field_choices)
self.assertEqual(cfv.value, 'qwerty')

def test_add_new_customfield_value_with_unmatching_managing_group_should_fail(self): # noqa" E501

self.custom_field_str.managing_group = GroupFactory()
self.custom_field_str.save()

some_object = SomeModel.objects.create(name='DEADBEEF')

url = reverse(self.list_view_name, args=(some_object.id,))
data = {
'value': 'qwerty',
'custom_field': self.custom_field_str.id,
}

response = self.client.post(url, data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_add_new_customfield_value_with_matching_managing_group_should_succeed(self): # noqa" E501
group = GroupFactory()
self.user.groups.add(group)
self.custom_field_str.managing_group = group

self.custom_field_str.save()
self.user.save()

some_object = SomeModel.objects.create(name='DEADBEEF')

url = reverse(self.list_view_name, args=(some_object.id,))
data = {
'value': 'qwerty',
'custom_field': self.custom_field_str.id,
}

response = self.client.post(url, data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)

def test_update_customfield_value_with_unmatching_managing_group_should_fail(self): # noqa: E501
self.custom_field_str.managing_group = GroupFactory()
self.custom_field_str.save()

url = reverse(
self.detail_view_name,
kwargs={'pk': self.cfv1.pk, 'object_pk': self.cfv1.object_id}
)
data = {
'value': 'NEW-VALUE',
'custom_field': self.custom_field_str.id,
}
response = self.client.put(url, data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_update_customfield_value_with_matching_managing_group_should_pass(self): # noqa: E501
group = GroupFactory()
self.user.groups.add(group)
self.custom_field_str.managing_group = group

self.custom_field_str.save()
self.user.save()

url = reverse(
self.detail_view_name,
kwargs={'pk': self.cfv1.pk, 'object_pk': self.cfv1.object_id}
)
data = {
'value': 'NEW-VALUE',
'custom_field': self.custom_field_str.id,
}
response = self.client.put(url, data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.cfv1.refresh_from_db()
self.assertEqual(self.cfv1.object, self.sm1)
self.assertEqual(self.cfv1.custom_field, self.custom_field_str)
self.assertEqual(self.cfv1.value, 'NEW-VALUE')


def test_delete_custom_field_value_with_unmatching_managing_group_should_fail(self): # noqa: E501
self.custom_field_str.managing_group = GroupFactory()
self.custom_field_str.save()

url = reverse(
self.detail_view_name,
kwargs={'pk': self.cfv1.pk, 'object_pk': self.cfv1.object_id}
)
response = self.client.delete(url, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(
CustomFieldValue.objects.filter(pk=self.cfv1.pk).count(), 1
)

def test_delete_custom_field_value_with_matching_managing_group_should_pass(self): # noqa: E501
group = GroupFactory()
self.user.groups.add(group)
self.custom_field_str.managing_group = group

self.custom_field_str.save()
self.user.save()

url = reverse(
self.detail_view_name,
kwargs={'pk': self.cfv1.pk, 'object_pk': self.cfv1.object_id}
)
response = self.client.delete(url, format='json')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(
CustomFieldValue.objects.filter(pk=self.cfv1.pk).count(), 0
)

def test_add_new_customfield_value_should_send_api_post_create_signal(self): # noqa: E501
self._sig_called_with_instance = None

Expand Down

0 comments on commit cae66f4

Please sign in to comment.