Skip to content

Commit

Permalink
Merge a31a1ad into 23eaa8a
Browse files Browse the repository at this point in the history
  • Loading branch information
erinspace committed Jan 31, 2018
2 parents 23eaa8a + a31a1ad commit 2bb6c52
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 9 deletions.
31 changes: 31 additions & 0 deletions api/base/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from django import forms
from django.db.models import Q
from django.core.exceptions import ValidationError


class NullModelMultipleChoiceCaseInsensitiveField(forms.ModelMultipleChoiceField):

def __init__(self, *args, **kwargs):
# use the default empty label
kwargs.pop('empty_label')
super(NullModelMultipleChoiceCaseInsensitiveField, self).__init__(*args, **kwargs)

def clean(self, value):
# let a custom filter handle the actual filtering for null values later in the qs
if value == 'null':
return value

try:
return super(NullModelMultipleChoiceCaseInsensitiveField, self).clean(value)

except ValidationError as validation_error:
# Check to make sure the validation error wasn't because of a case sensitive relationship query
q = Q()
for choice in value:
q |= Q(**{'{}__iexact'.format(self.to_field_name): choice})
queryset = self.queryset.filter(q)

if not queryset:
raise validation_error

return queryset
145 changes: 145 additions & 0 deletions api/base/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import operator
import re

import django_filters
from django.db import models
import pytz
from guardian.shortcuts import get_objects_for_user
from api.base import utils
from api.base.fields import NullModelMultipleChoiceCaseInsensitiveField
from api.base.exceptions import (InvalidFilterComparisonType,
InvalidFilterError, InvalidFilterFieldError,
InvalidFilterMatchType, InvalidFilterOperator,
Expand Down Expand Up @@ -66,6 +69,148 @@ def filter_queryset(self, request, queryset, view):
return queryset


class MultiValueCharFilter(django_filters.BaseInFilter, django_filters.filters.CharFilter):

def filter(self, qs, value):
q = Q()
values = value or []
for value in values:
q |= Q(**{'{}__icontains'.format(self.name): value.strip()})
qs = qs.filter(q)

return qs


class NullModelMultipleChoiceFilter(django_filters.ModelChoiceFilter):

field_class = NullModelMultipleChoiceCaseInsensitiveField

def __init__(self, *args, **kwargs):
# use base_name to filter later on for a null relationship
self.base_name = kwargs.pop('base_name')
super(NullModelMultipleChoiceFilter, self).__init__(*args, **kwargs)

def filter(self, qs, value):
if value == 'null':
return qs.filter(Q(**{'{}__isnull'.format(self.base_name): True}))

q = Q()
values = value or []
for value in values:
q |= Q(**{self.name: value})
return qs.filter(q)


class JSONAPIFilterSet(django_filters.rest_framework.FilterSet):

QUERY_PATTERN = re.compile(r'^filter\[(?P<fields>((?:,*\s*\w+)*))\](\[(?P<op>\w+)\])?$')
FILTER_FIELDS = re.compile(r'(?:,*\s*(\w+)+)')

# TODO - there must be a better way to recognize these
MANY_TO_MANY_FIELDS = ['tags', 'contributors']
DATE_FIELDS = ['created', 'modified']

def __init__(self, data=None, *args, **kwargs):
self.or_fields = {}
self.operator_fields = []
field_names = []
if data:
new_data = {}
for key, value in data.iteritems():
match = self.QUERY_PATTERN.match(key)
if match:
match_dict = match.groupdict()
fields = match_dict['fields']
field_names = self.FILTER_FIELDS.findall(fields.strip())
op = match_dict.get('op', None)
# some filter values can be passed in brackets, with spaces between them
value = value.replace('[', '').replace(']', '')
values = value.split(',')
if len(field_names) > 1:
self.or_fields[frozenset(field_names)] = values
elif op and op != 'eq':
self.operator_fields.append({'field': field_names[0], 'value': value, 'op': op})
else:
new_data.update({field_names[0]: value})
data = self.postprocess_data(new_data)
super(JSONAPIFilterSet, self).__init__(data=data, *args, **kwargs)

for field in field_names:
if field not in self.form.fields:
raise InvalidFilterFieldError(parameter='filter', value=field)

def postprocess_data(self, data):
data_to_return = {}
for field_name, value in data.iteritems():
# to filter on a relationship field, values must be in a list
if field_name in self.MANY_TO_MANY_FIELDS and value != 'null':
if field_name == 'contributors':
field_name = '_contributors'
value = value if type(value) == list else value.split(',')
if value == 'true' or value == 'false':
value = value.title()
if field_name in self.DATE_FIELDS:
try:
value_datetime = date_parser.parse(value, ignoretz=False)
if not value_datetime.tzinfo:
value = value_datetime.replace(tzinfo=pytz.utc).isoformat()
except ValueError:
raise InvalidFilterValue(
value=value,
field_type='date'
)

data_to_return[field_name] = value

return data_to_return

@property
def qs(self, *args, **kwargs):
self.form.is_valid()
qs = super(JSONAPIFilterSet, self).qs
if self.operator_fields:
qs = self.filter_operators(qs)
return self.filter_groups(qs)

def filter_groups(self, qs):
for group, values in self.or_fields.iteritems():
group_q = Q()
for field in group:
if field not in self.form.fields.keys():
raise InvalidFilterFieldError(parameter='filter', value=field)
for value in values:
group_q |= Q(**{self.filters[field].name: value})
qs = qs.filter(group_q)

return qs

def filter_operators(self, qs):
for field_dict in self.operator_fields:
value = field_dict['value']
field = field_dict['field']
op = field_dict['op']

if field in self.DATE_FIELDS:
value = datetime.datetime.strptime(value, '%Y-%m-%d').replace(tzinfo=pytz.UTC)
if op == 'ne':
qs = qs.filter(~Q(**{field: value}))
else:
filter_left = '{}__{}'.format(field, op)
qs = qs.filter(Q(**{filter_left: value}))

return qs

class Meta:
filter_overrides = {
models.CharField: {
'filter_class': MultiValueCharFilter,
},
models.TextField: {
'filter_class': MultiValueCharFilter,
}
}


class FilterMixin(object):
""" View mixin with helper functions for filtering. """

Expand Down
48 changes: 45 additions & 3 deletions api/nodes/filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import django_filters
from copy import deepcopy

from django.db.models import Q
from django.conf import settings

from api.base.exceptions import InvalidFilterOperator, InvalidFilterValue
from api.base.filters import ListFilterMixin
from api.base.filters import ListFilterMixin, JSONAPIFilterSet, NullModelMultipleChoiceFilter, MultiValueCharFilter
from api.base import utils

from osf.models import NodeRelation, AbstractNode
from osf.models import NodeRelation, AbstractNode, Tag, OSFUser


class NodesFilterMixin(ListFilterMixin):
Expand Down Expand Up @@ -69,3 +70,44 @@ def build_query_from_field(self, field_name, operation):
return ~not_preprint_query if utils.is_truthy(operation['value']) else not_preprint_query

return super(NodesFilterMixin, self).build_query_from_field(field_name, operation)


class NodeFilterSet(JSONAPIFilterSet):
id = MultiValueCharFilter(name='guids___id')
public = django_filters.BooleanFilter(name='is_public')
tags = NullModelMultipleChoiceFilter(name='tags__name', queryset=Tag.objects.all(), to_field_name='name', lookup_expr='in', base_name='tags')
category = django_filters.ChoiceFilter(choices=settings.NODE_CATEGORY_MAP.items())
preprint = django_filters.CharFilter(method='filter_preprint')
contributors = django_filters.ModelMultipleChoiceFilter(name='_contributors__guids___id', queryset=OSFUser.objects.all(), to_field_name='_guids___id')
root = django_filters.CharFilter(method='filter_root')
parent = django_filters.CharFilter(method='filter_parent')

def filter_parent(self, queryset, name, value):
if value == 'null':
return queryset.get_roots()
parent = utils.get_object_or_error(AbstractNode, value, display_name='parent')
node_ids = NodeRelation.objects.filter(parent=parent, is_node_link=False).values_list('child_id', flat=True)

return queryset.filter(id__in=node_ids)

def filter_root(self, queryset, name, value):
if value == 'null':
raise InvalidFilterValue(value=value)
return queryset.filter(root__guids___id=value)

def filter_preprint(self, queryset, name, value):
preprint_filters = (
Q(preprint_file=None) |
Q(_is_preprint_orphan=True) |
Q(_has_abandoned_preprint=True)
)
return queryset.exclude(preprint_filters) if utils.is_truthy(value) else queryset.filter(preprint_filters)

class Meta(JSONAPIFilterSet.Meta):
model = AbstractNode
fields = [
'title',
'description',
'created',
'modified',
]
11 changes: 8 additions & 3 deletions api/nodes/views.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import re

import django_filters
from django.apps import apps
from django.db.models import Q, OuterRef, Exists
from django.utils import timezone
from rest_framework import generics, permissions as drf_permissions
from rest_framework.exceptions import PermissionDenied, ValidationError, NotFound, MethodNotAllowed, NotAuthenticated
from rest_framework.response import Response
from rest_framework.status import HTTP_204_NO_CONTENT
from rest_framework.filters import OrderingFilter

from addons.osfstorage.models import OsfStorageFolder
from api.addons.serializers import NodeAddonFolderSerializer
Expand Down Expand Up @@ -59,7 +61,7 @@
from api.identifiers.views import IdentifierList
from api.institutions.serializers import InstitutionSerializer
from api.logs.serializers import NodeLogSerializer
from api.nodes.filters import NodesFilterMixin
from api.nodes.filters import NodeFilterSet, NodesFilterMixin
from api.nodes.permissions import (
IsAdmin,
IsPublic,
Expand Down Expand Up @@ -172,7 +174,7 @@ def get_draft(self, draft_id=None):
return draft


class NodeList(JSONAPIBaseView, bulk_views.BulkUpdateJSONAPIView, bulk_views.BulkDestroyJSONAPIView, bulk_views.ListBulkCreateJSONAPIView, NodesFilterMixin, WaterButlerMixin):
class NodeList(JSONAPIBaseView, bulk_views.BulkUpdateJSONAPIView, bulk_views.BulkDestroyJSONAPIView, bulk_views.ListBulkCreateJSONAPIView, WaterButlerMixin):
"""Nodes that represent projects and components. *Writeable*.
Paginated list of nodes ordered by their `modified`. Each resource contains the full representation of the
Expand Down Expand Up @@ -275,6 +277,9 @@ class NodeList(JSONAPIBaseView, bulk_views.BulkUpdateJSONAPIView, bulk_views.Bul
view_category = 'nodes'
view_name = 'node-list'

filter_backends = (django_filters.rest_framework.DjangoFilterBackend, OrderingFilter)
filter_class = NodeFilterSet

ordering = ('-modified', ) # default ordering

# overrides NodesFilterMixin
Expand All @@ -299,7 +304,7 @@ def get_queryset(self):
raise PermissionDenied
return nodes
else:
return self.get_queryset_from_request()
return self.get_default_queryset()

# overrides ListBulkCreateJSONAPIView, BulkUpdateJSONAPIView, BulkDestroyJSONAPIView
def get_serializer_class(self):
Expand Down
14 changes: 14 additions & 0 deletions api/users/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import django_filters

from osf.models import OSFUser
from api.base.filters import JSONAPIFilterSet, MultiValueCharFilter


class UserFilterSet(JSONAPIFilterSet):

full_name = MultiValueCharFilter(name='fullname', lookup_expr='icontains')
id = django_filters.CharFilter(name='guids___id')

class Meta(JSONAPIFilterSet.Meta):
model = OSFUser
fields = ['id', 'full_name', 'given_name', 'middle_names', 'family_name']
10 changes: 8 additions & 2 deletions api/users/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import django_filters
from django.apps import apps

from api.addons.views import AddonSettingsMixin
Expand All @@ -17,6 +18,7 @@
from api.nodes.serializers import NodeSerializer
from api.preprints.serializers import PreprintSerializer
from api.registrations.serializers import RegistrationSerializer
from api.users.filters import UserFilterSet
from api.users.permissions import (CurrentUser, ReadOnlyOrCurrentUser,
ReadOnlyOrCurrentUserRelationship)
from api.users.serializers import (UserAddonSettingsSerializer,
Expand Down Expand Up @@ -85,7 +87,8 @@ def get_user(self, check_permissions=True):
return obj


class UserList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin):
class UserList(JSONAPIBaseView, generics.ListAPIView):

"""List of users registered on the OSF.
Paginated list of users ordered by the date they registered. Each resource contains the full representation of the
Expand Down Expand Up @@ -145,6 +148,9 @@ class UserList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin):

serializer_class = UserSerializer

filter_backends = (django_filters.rest_framework.DjangoFilterBackend,)
filter_class = UserFilterSet

ordering = ('-date_registered')
view_category = 'users'
view_name = 'user-list'
Expand All @@ -156,7 +162,7 @@ def get_default_queryset(self):

# overrides ListCreateAPIView
def get_queryset(self):
return self.get_queryset_from_request()
return self.get_default_queryset()


class UserDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView, UserMixin):
Expand Down
1 change: 1 addition & 0 deletions api_tests/nodes/filters/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
NodeRelationFactory,
ProjectFactory,
)
from tests.utils import assert_items_equal
from framework.auth.core import Auth


Expand Down
2 changes: 1 addition & 1 deletion api_tests/nodes/views/test_node_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def test_incorrect_filtering_field_not_logged_in(self, app):
assert res.status_code == 400
errors = res.json['errors']
assert len(errors) == 1
assert errors[0]['detail'] == '\'notafield\' is not a valid field for this endpoint.'
assert errors[0]['detail'] == "Value 'notafield' is not a filterable field."

def test_filtering_on_root(self, app, user_one):
root = ProjectFactory(is_public=True)
Expand Down
Loading

0 comments on commit 2bb6c52

Please sign in to comment.