Skip to content

Commit

Permalink
Merge 48634a5 into 937ae01
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyrollins committed Jun 11, 2018
2 parents 937ae01 + 48634a5 commit 7ed9950
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 30 deletions.
25 changes: 20 additions & 5 deletions api/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from django.utils.http import urlquote
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import OuterRef, Exists, Q
from django.db.models import OuterRef, Exists, Q, QuerySet, F
from rest_framework.exceptions import NotFound
from rest_framework.reverse import reverse

Expand Down Expand Up @@ -76,10 +76,24 @@ def absolute_reverse(view_name, query_kwargs=None, args=None, kwargs=None):
return url


def get_object_or_error(model_cls, query_or_pk, request, display_name=None):
def get_object_or_error(model_or_qs, query_or_pk=None, request=None, display_name=None):
if not request:
# for backwards compat with existing get_object_or_error usages
raise TypeError('request is a required argument')

obj = query = None
model_cls = model_or_qs
select_for_update = check_select_for_update(request)
if isinstance(query_or_pk, basestring):

if isinstance(model_or_qs, QuerySet):
# they passed a queryset
model_cls = model_or_qs.model
try:
obj = model_or_qs.select_for_update().get() if select_for_update else model_or_qs.get()
except model_cls.DoesNotExist:
raise NotFound

elif isinstance(query_or_pk, basestring):
# they passed a 5-char guid as a string
if issubclass(model_cls, GuidMixin):
# if it's a subclass of GuidMixin we know it's primary_identifier_name
Expand Down Expand Up @@ -126,7 +140,7 @@ def get_object_or_error(model_cls, query_or_pk, request, display_name=None):

def default_node_list_queryset(model_cls):
assert model_cls in {Node, Registration}
return model_cls.objects.filter(is_deleted=False)
return model_cls.objects.filter(is_deleted=False).annotate(region=F('addons_osfstorage_node_settings__region___id'))

def default_node_permission_queryset(user, model_cls):
assert model_cls in {Node, Registration}
Expand All @@ -139,7 +153,8 @@ def default_node_list_permission_queryset(user, model_cls):
# **DO NOT** change the order of the querysets below.
# If get_roots() is called on default_node_list_qs & default_node_permission_qs,
# Django's alaising will break and the resulting QS will be empty and you will be sad.
return default_node_permission_queryset(user, model_cls) & default_node_list_queryset(model_cls)
qs = default_node_permission_queryset(user, model_cls) & default_node_list_queryset(model_cls)
return qs.annotate(region=F('addons_osfstorage_node_settings__region___id'))

def extend_querystring_params(url, params):
scheme, netloc, path, query, _ = urlparse.urlsplit(url)
Expand Down
11 changes: 10 additions & 1 deletion api/base/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django_bulk_update.helper import bulk_update
from django.conf import settings as django_settings
from django.db import transaction
from django.db.models import F
from django.http import JsonResponse
from rest_framework import generics
from rest_framework import permissions as drf_permissions
Expand Down Expand Up @@ -524,7 +525,15 @@ class BaseLinkedList(JSONAPIBaseView, generics.ListAPIView):
def get_queryset(self):
auth = get_user_auth(self.request)

return self.get_node().linked_nodes.filter(is_deleted=False).exclude(type='osf.collection').can_view(user=auth.user, private_link=auth.private_link).order_by('-modified')
return (
self.get_node().linked_nodes
.filter(is_deleted=False)
.annotate(region=F('addons_osfstorage_node_settings__region___id'))
.exclude(region=None)
.exclude(type='osf.collection', region=None)
.can_view(user=auth.user, private_link=auth.private_link)
.order_by('-modified')
)


class WaterButlerMixin(object):
Expand Down
2 changes: 2 additions & 0 deletions api/institutions/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.db.models import F
from rest_framework import generics
from rest_framework import permissions as drf_permissions
from rest_framework import exceptions
Expand Down Expand Up @@ -118,6 +119,7 @@ def get_default_queryset(self):
institution.nodes.filter(is_public=True, is_deleted=False, type='osf.node')
.select_related('node_license', 'preprint_file')
.include('contributor__user__guids', 'root__guids', 'tags', limit_includes=10)
.annotate(region=F('addons_osfstorage_node_settings__region___id'))
)

# overrides RetrieveAPIView
Expand Down
39 changes: 35 additions & 4 deletions api/nodes/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from rest_framework import serializers as ser
from rest_framework import exceptions
from addons.base.exceptions import InvalidAuthError, InvalidFolderError
from addons.osfstorage.models import Region
from website.exceptions import NodeStateError
from osf.models import (Comment, DraftRegistration, Institution,
MetaSchema, AbstractNode, PrivateLink)
Expand Down Expand Up @@ -75,6 +76,20 @@ def update_institutions(node, new_institutions, user, post=False):
node.add_affiliated_institution(inst, user)


class RegionRelationshipField(RelationshipField):

def get_object(self, region_id):
try:
region = Region.objects.get(_id=region_id)
except Region.DoesNotExist:
raise exceptions.ValidationError(detail='Region {} is invalid.'.format(region_id))
return region

def to_internal_value(self, data):
region = self.get_object(data)
return {'region_id': region.id}


class NodeTagField(ser.Field):
def to_representation(self, obj):
if obj is not None:
Expand Down Expand Up @@ -319,10 +334,10 @@ class NodeSerializer(TaxonomizableSerializerMixin, JSONAPISerializer):
related_meta={'count': 'get_registration_count'}
))

region = RelationshipField(
region = RegionRelationshipField(
related_view='regions:region-detail',
related_view_kwargs={'region_id': '<osfstorage_region._id>'},
read_only=True
related_view_kwargs={'region_id': 'get_region_id'},
read_only=False
)

affiliated_institutions = RelationshipField(
Expand Down Expand Up @@ -473,14 +488,27 @@ def get_unread_comments_count(self, obj):
'node': node_comments
}

def get_region_id(self, obj):
try:
# use the annotated value if possible
region = obj.region
except AttributeError:
# use computed property if region annotation does not exist
# i.e. after creating a node
region = obj.osfstorage_region
return region._id

def create(self, validated_data):
request = self.context['request']
user = request.user
Node = apps.get_model('osf.Node')
tag_instances = []
affiliated_institutions = None
region_id = None
if 'affiliated_institutions' in validated_data:
affiliated_institutions = validated_data.pop('affiliated_institutions')
if 'region_id' in validated_data:
region_id = validated_data.pop('region_id')
if 'tags' in validated_data:
tags = validated_data.pop('tags')
for tag in tags:
Expand Down Expand Up @@ -530,7 +558,8 @@ def create(self, validated_data):
node.subjects.add(parent.subjects.all())
node.save()

region_id = self.context.get('region_id')
if not region_id:
region_id = self.context.get('region_id')
if region_id:
node_settings = node.get_addon('osfstorage')
node_settings.region_id = region_id
Expand All @@ -550,6 +579,8 @@ def update(self, node, validated_data):
if 'tags' in validated_data:
new_tags = set(validated_data.pop('tags', []))
node.update_tags(new_tags, auth=auth)
if 'region' in validated_data:
validated_data.pop('region')
if 'license_type' in validated_data or 'license' in validated_data:
license_details = get_license_details(node, validated_data)
validated_data['node_license'] = license_details
Expand Down
47 changes: 32 additions & 15 deletions api/nodes/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

from django.apps import apps
from django.db.models import Q, OuterRef, Exists
from django.db.models import Q, OuterRef, Exists, F
from django.utils import timezone
from rest_framework import generics, permissions as drf_permissions
from rest_framework.exceptions import PermissionDenied, ValidationError, NotFound, MethodNotAllowed, NotAuthenticated
Expand Down Expand Up @@ -132,14 +132,13 @@ def get_node(self, check_object_permissions=True):
# If this is an embedded request, the node might be cached somewhere
node = self.request.parents[Node].get(self.kwargs[self.node_lookup_url_kwarg])

node_id = self.kwargs[self.node_lookup_url_kwarg]
if node is None:
node = get_object_or_error(
Node,
self.kwargs[self.node_lookup_url_kwarg],
self.request,
Node.objects.filter(guids___id=node_id).annotate(region=F('addons_osfstorage_node_settings__region___id')).exclude(region=None),
request=self.request,
display_name='node'
)

# Nodes that are folders/collections are treated as a separate resource, so if the client
# requests a collection through a node endpoint, we return a 404
if node.is_collection or node.is_registration:
Expand Down Expand Up @@ -234,17 +233,15 @@ def get_serializer_class(self):

def get_serializer_context(self):
context = super(NodeList, self).get_serializer_context()
region__id = self.request.query_params.get('region', None)
id = None
if region__id:
region_id = self.request.query_params.get('region', None)
if region_id:
try:
id = Region.objects.get(_id=region__id).id
region = Region.objects.get(_id=region_id)
except Region.DoesNotExist:
raise InvalidQueryStringError('Region {} is invalid.'.format(region__id))

context.update({
'region_id': id
})
raise InvalidQueryStringError('Region {} is invalid.'.format(region_id))
context.update({
'region_id': region.id
})
return context

# overrides ListBulkCreateJSONAPIView
Expand Down Expand Up @@ -638,6 +635,19 @@ def perform_create(self, serializer):
user = self.request.user
serializer.save(creator=user, parent=self.get_node())

def get_serializer_context(self):
context = super(NodeChildrenList, self).get_serializer_context()
region_id = self.request.query_params.get('region', None)
if region_id:
try:
region = Region.objects.get(_id=region_id)
except Region.DoesNotExist:
raise InvalidQueryStringError('Region {} is invalid.'.format(region_id))
context.update({
'region_id': region.id
})
return context


class NodeCitationDetail(JSONAPIBaseView, generics.RetrieveAPIView, NodeMixin):
"""The documentation for this endpoint can be found [here](https://developer.osf.io/#operation/nodes_citation_list).
Expand Down Expand Up @@ -896,7 +906,14 @@ class NodeForksList(JSONAPIBaseView, generics.ListCreateAPIView, NodeMixin, Node

# overrides ListCreateAPIView
def get_queryset(self):
all_forks = self.get_node().forks.exclude(type='osf.registration').exclude(is_deleted=True).order_by('-forked_date')
all_forks = (
self.get_node().forks
.annotate(region=F('addons_osfstorage_node_settings__region___id'))
.exclude(region=None)
.exclude(type='osf.registration')
.exclude(is_deleted=True)
.order_by('-forked_date')
)
auth = get_user_auth(self.request)

node_pks = [node.pk for node in all_forks if node.can_view(auth)]
Expand Down
15 changes: 15 additions & 0 deletions api/users/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ class UserSerializer(JSONAPISerializer):
related_view_kwargs={'user_id': '<_id>'},
))

default_region = ShowIfCurrentUser(RelationshipField(
related_view='regions:region-detail',
related_view_kwargs={'region_id': 'get_default_region_id'},
read_only=True
))

class Meta:
type_ = 'users'

Expand All @@ -119,6 +125,15 @@ def get_can_view_reviews(self, obj):
group_qs = GroupObjectPermission.objects.filter(group__user=obj, permission__codename='view_submissions')
return group_qs.exists() or obj.userobjectpermission_set.filter(permission__codename='view_submissions')

def get_default_region_id(self, obj):
try:
# use the annotated value if possible
region = obj.default_region
except AttributeError:
# use computed property if region annotation does not exist
region = obj.osfstorage_region
return region._id

def get_accepted_terms_of_service(self, obj):
return bool(obj.accepted_terms_of_service)

Expand Down
21 changes: 16 additions & 5 deletions api/users/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.apps import apps
from django.db.models import F

from api.addons.views import AddonSettingsMixin
from api.base import permissions as base_permissions
Expand Down Expand Up @@ -64,21 +65,31 @@ def get_user(self, check_permissions=True):
if user._id == key:
if check_permissions:
self.check_object_permissions(self.request, user)
return user
return get_object_or_error(
OSFUser.objects.filter(id=user.id).annotate(default_region=F('addons_osfstorage_user_settings__default_region___id')).exclude(default_region=None),
request=self.request,
display_name='user'
)

if self.kwargs.get('is_embedded') is True:
if key in self.request.parents[OSFUser]:
return self.request.parents[OSFUser].get(key)

current_user = self.request.user

if key == 'me':
if isinstance(current_user, AnonymousUser):
if isinstance(current_user, AnonymousUser):
if key == 'me':
raise NotAuthenticated
else:
return self.request.user

elif key == 'me' or key == current_user._id:
return get_object_or_error(
OSFUser.objects.filter(id=current_user.id).annotate(default_region=F('addons_osfstorage_user_settings__default_region___id')).exclude(default_region=None),
request=self.request,
display_name='user'
)

obj = get_object_or_error(OSFUser, key, self.request, 'user')

if check_permissions:
# May raise a permission denied
self.check_object_permissions(self.request, obj)
Expand Down
19 changes: 19 additions & 0 deletions api_tests/nodes/views/test_node_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,25 @@ def test_create_component_inherit_contributors_with_unregistered_contributor(
new_component.contributors
) == len(parent_project.contributors)

def test_create_project_with_region_relationship(
self, app, user_one, region, private_project, url):
private_project['data']['relationships'] = {
'region': {
'data': {
'type': 'region',
'id': region._id
}
}
}
res = app.post_json_api(
url, private_project, auth=user_one.auth
)
assert res.status_code == 201
project = AbstractNode.load(res.json['data']['id'])

node_settings = project.get_addon('osfstorage')
assert node_settings.region_id == region.id

def test_create_project_with_region_query_param(
self, app, user_one, region, private_project, url_with_region_query_param):
res = app.post_json_api(
Expand Down
6 changes: 6 additions & 0 deletions osf/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,12 @@ def csl_name(self, node_id=None):
'given': csl_given_name,
}

@property
def osfstorage_region(self):
from addons.osfstorage.models import Region
user_settings = self.get_addon('osfstorage')
return Region.objects.get(id=user_settings.default_region_id)

@property
def contributor_to(self):
return self.nodes.filter(is_deleted=False, type__in=['osf.node', 'osf.registration'])
Expand Down

0 comments on commit 7ed9950

Please sign in to comment.