Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions addons/base/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,7 @@ def _enqueue_metrics(file_version, file_node, action, auth, from_mfr=False):
def _construct_payload(auth, resource, credentials, waterbutler_settings):

if isinstance(resource, Registration):
callback_url = resource.api_url_for(
'registration_callbacks',
_absolute=True,
_internal=True
)
callback_url = resource.callbacks_url
else:
callback_url = resource.api_url_for(
'create_waterbutler_log',
Expand Down
1 change: 1 addition & 0 deletions api/registrations/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
re_path(r'^(?P<node_id>\w+)/$', views.RegistrationDetail.as_view(), name=views.RegistrationDetail.view_name),
re_path(r'^(?P<node_id>\w+)/bibliographic_contributors/$', views.RegistrationBibliographicContributorsList.as_view(), name=views.RegistrationBibliographicContributorsList.view_name),
re_path(r'^(?P<node_id>\w+)/cedar_metadata_records/$', views.RegistrationCedarMetadataRecordsList.as_view(), name=views.RegistrationCedarMetadataRecordsList.view_name),
re_path(r'^(?P<node_id>\w+)/callbacks/$', views.RegistrationCallbackView.as_view(), name=views.RegistrationCallbackView.view_name),
re_path(r'^(?P<node_id>\w+)/children/$', views.RegistrationChildrenList.as_view(), name=views.RegistrationChildrenList.view_name),
re_path(r'^(?P<node_id>\w+)/comments/$', views.RegistrationCommentsList.as_view(), name=views.RegistrationCommentsList.view_name),
re_path(r'^(?P<node_id>\w+)/contributors/$', views.RegistrationContributorsList.as_view(), name=views.RegistrationContributorsList.view_name),
Expand Down
53 changes: 52 additions & 1 deletion api/registrations/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from rest_framework import generics, mixins, permissions as drf_permissions
from rest_framework import generics, mixins, permissions as drf_permissions, status
from rest_framework.exceptions import ValidationError, NotFound, PermissionDenied
from rest_framework.response import Response
from framework.exceptions import HTTPError
from framework.auth.oauth_scopes import CoreScopes

from addons.base.views import DOWNLOAD_ACTIONS
from website.archiver import signals, ARCHIVER_NETWORK_ERROR, ARCHIVER_SUCCESS, ARCHIVER_FAILURE
from website.project import signals as project_signals

from osf.models import Registration, OSFUser, RegistrationProvider, OutcomeArtifact, CedarMetadataRecord
from osf.utils.permissions import WRITE_NODE
from osf.utils.workflows import ApprovalStates
Expand All @@ -28,6 +34,7 @@
JSONAPIMultipleRelationshipsParser,
JSONAPIRelationshipParserForRegularJSON,
JSONAPIMultipleRelationshipsParserForRegularJSON,
HMACSignedParser,
)
from api.base.utils import (
get_user_auth,
Expand Down Expand Up @@ -1038,3 +1045,47 @@ def get_default_queryset(self):

def get_queryset(self):
return self.get_queryset_from_request()


class RegistrationCallbackView(JSONAPIBaseView, generics.UpdateAPIView, RegistrationMixin):
permission_classes = [drf_permissions.AllowAny]

view_category = 'registrations'
view_name = 'registration-callbacks'

parser_classes = [HMACSignedParser]

def update(self, request, *args, **kwargs):
registration = self.get_node()

try:
payload = request.data
if payload.get('action', None) in DOWNLOAD_ACTIONS:
return Response({'status': 'success'}, status=status.HTTP_200_OK)
errors = payload.get('errors')
src_provider = payload['source']['provider']
if errors:
registration.archive_job.update_target(
src_provider,
ARCHIVER_FAILURE,
errors=errors,
)
else:
# Dataverse requires two seperate targets, one
# for draft files and one for published files
if src_provider == 'dataverse':
src_provider += '-' + (payload['destination']['name'].split(' ')[-1].lstrip('(').rstrip(')').strip())
registration.archive_job.update_target(
src_provider,
ARCHIVER_SUCCESS,
)
project_signals.archive_callback.send(registration)
return Response(status=status.HTTP_200_OK)
except HTTPError as e:
registration.archive_status = ARCHIVER_NETWORK_ERROR
registration.save()
signals.archive_fail.send(
registration,
errors=[str(e)],
)
return Response(status=status.HTTP_200_OK)
2 changes: 2 additions & 0 deletions api_tests/base/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MetricsOpenapiView,
)
from api.users.views import ClaimUser, ResetPassword, ExternalLoginConfirmEmailView, ExternalLogin
from api.registrations.views import RegistrationCallbackView
from api.wb.views import MoveFileMetadataView, CopyFileMetadataView
from rest_framework.permissions import IsAuthenticatedOrReadOnly, IsAuthenticated
from api.base.permissions import TokenHasScope
Expand Down Expand Up @@ -63,6 +64,7 @@ def setUp(self):
ResetPassword,
ExternalLoginConfirmEmailView,
ExternalLogin,
RegistrationCallbackView,
]

def test_root_returns_200(self):
Expand Down
82 changes: 82 additions & 0 deletions api_tests/registrations/views/test_regisatration_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import copy
import time
import pytest

from api.base.settings.defaults import API_BASE
from osf_tests.factories import RegistrationFactory
from framework.auth import signing


@pytest.mark.django_db
class TestRegistrationCallbacks:

@pytest.fixture()
def registration(self):
registration = RegistrationFactory()
return registration

@pytest.fixture()
def url(self, registration):
return f'/{API_BASE}registrations/{registration._id}/callbacks/'

@pytest.fixture()
def payload(self):
return {
'action': 'copy',
'destination': {
'name': 'Archive of OSF Storage',
},
'errors': None,
'source': {
'provider': 'osfstorage',
},
'time': time.time() + 1000
}

def sign_payload(self, payload):
message, signature = signing.default_signer.sign_payload(payload)
return {
'payload': message,
'signature': signature,
}

def test_registration_callback(self, app, payload, url):
data = self.sign_payload(payload)
res = app.put_json(url, data)
assert res.status_code == 200

def test_signature_expired(self, app, payload, url):
payload['time'] = time.time() - 100
data = self.sign_payload(payload)
res = app.put_json(url, data, expect_errors=True)
assert res.status_code == 400
assert res.json['errors'][0]['detail'] == 'Signature has expired'

def test_bad_signature(self, app, payload, url):
data = self.sign_payload(payload)
data['signature'] = '1234'
res = app.put_json(url, data, expect_errors=True)
assert res.status_code == 401
assert res.json['errors'][0]['detail'] == 'Authentication credentials were not provided.'

def test_invalid_payload(self, app, payload, url):
payload1 = copy.deepcopy(payload)
del payload1['time']
data = self.sign_payload(payload1)
res = app.put_json(url, data, expect_errors=True)
assert res.status_code == 400
assert res.json['errors'][0]['detail'] == 'Invalid Payload'

payload2 = copy.deepcopy(payload)
data = self.sign_payload(payload2)
del data['signature']
res = app.put_json(url, data, expect_errors=True)
assert res.status_code == 400
assert res.json['errors'][0]['detail'] == 'Invalid Payload'

payload3 = copy.deepcopy(payload)
data = self.sign_payload(payload3)
del data['payload']
res = app.put_json(url, data, expect_errors=True)
assert res.status_code == 400
assert res.json['errors'][0]['detail'] == 'Invalid Payload'
7 changes: 7 additions & 0 deletions osf/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@ def institutions_url(self):
def institutions_relationship_url(self):
return self.absolute_api_v2_url + 'relationships/institutions/'

@property
def callbacks_url(self):
return self.absolute_api_v2_url + 'callbacks/'

# For Comment API compatibility
@property
def target_type(self):
Expand Down Expand Up @@ -664,6 +668,9 @@ def web_url_for(self, view_name, _absolute=False, _guid=False, *args, **kwargs):
def api_url_for(self, view_name, _absolute=False, *args, **kwargs):
return api_url_for(view_name, pid=self._primary_key, _absolute=_absolute, *args, **kwargs)

def api_v2_url_for(self, path_str, params=None, **kwargs):
return api_url_for(path_str, params=params, **kwargs)

@property
def project_or_component(self):
# The distinction is drawn based on whether something has a parent node, rather than by category
Expand Down
17 changes: 0 additions & 17 deletions osf_tests/test_archiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from website.app import * # noqa: F403
from website.archiver import listeners
from website.archiver.tasks import * # noqa: F403
from website.archiver.decorators import fail_archive_on_error

from osf.models import Guid, RegistrationSchema, Registration
from osf.models.archive import ArchiveTarget, ArchiveJob
Expand Down Expand Up @@ -1111,22 +1110,6 @@ def test_find_failed_registrations(self):
assert pk not in failed


class TestArchiverDecorators(ArchiverTestCase):

@mock.patch('website.archiver.signals.archive_fail.send')
def test_fail_archive_on_error(self, mock_fail):
e = HTTPError(418)

def error(*args, **kwargs):
raise e

func = fail_archive_on_error(error)
func(node=self.dst)
mock_fail.assert_called_with(
self.dst,
errors=[str(e)]
)

class TestArchiverBehavior(OsfTestCase):

@mock.patch('osf.models.AbstractNode.update_search')
Expand Down
25 changes: 0 additions & 25 deletions website/archiver/decorators.py

This file was deleted.

19 changes: 0 additions & 19 deletions website/project/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,25 +173,6 @@ def wrapped(*args, **kwargs):

return wrapped

def must_be_registration(func):

@functools.wraps(func)
def wrapped(*args, **kwargs):
_inject_nodes(kwargs)
node = kwargs['node']

if not node.is_registration:
raise HTTPError(
http_status.HTTP_400_BAD_REQUEST,
data={
'message_short': 'Registered Nodes only',
'message_long': 'This view is restricted to registered Nodes only',
}
)
return func(*args, **kwargs)

return wrapped


def check_can_download_preprint_file(user, node):
"""View helper that returns whether a given user can download unpublished preprint files.
Expand Down
34 changes: 1 addition & 33 deletions website/project/views/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@
from framework.exceptions import HTTPError
from framework.flask import redirect # VOL-aware redirect

from framework.auth.decorators import must_be_signed

from website.archiver import ARCHIVER_SUCCESS, ARCHIVER_FAILURE

from addons.base.views import DOWNLOAD_ACTIONS
from website import settings
from osf.exceptions import NodeStateError
from website.project.decorators import (
must_be_valid_project, must_be_contributor_or_public,
must_have_permission, must_be_contributor_and_not_group_member,
must_not_be_registration, must_be_registration,
must_not_be_registration,
must_not_be_retracted_registration
)
from osf import features
Expand All @@ -26,12 +21,10 @@
from osf.utils.permissions import ADMIN
from website import language
from website.ember_osf_web.decorators import ember_flag_is_active
from website.project import signals as project_signals
from website.project.metadata.schemas import _id_to_name
from website import util
from website.project.metadata.utils import serialize_meta_schema
from website.project.model import has_anonymous_link
from website.archiver.decorators import fail_archive_on_error

from .node import _view_project
from api.waffle.utils import flag_is_active
Expand Down Expand Up @@ -228,28 +221,3 @@ def get_referent_by_identifier(category, value):
if identifier.referent.url:
return redirect(identifier.referent.url)
raise HTTPError(http_status.HTTP_404_NOT_FOUND)

@fail_archive_on_error
@must_be_signed
@must_be_registration
def registration_callbacks(node, payload, *args, **kwargs):
if payload.get('action', None) in DOWNLOAD_ACTIONS:
return {'status': 'success'}
errors = payload.get('errors')
src_provider = payload['source']['provider']
if errors:
node.archive_job.update_target(
src_provider,
ARCHIVER_FAILURE,
errors=errors,
)
else:
# Dataverse requires two seperate targets, one
# for draft files and one for published files
if src_provider == 'dataverse':
src_provider += '-' + (payload['destination']['name'].split(' ')[-1].lstrip('(').rstrip(')').strip())
node.archive_job.update_target(
src_provider,
ARCHIVER_SUCCESS,
)
project_signals.archive_callback.send(node)
8 changes: 0 additions & 8 deletions website/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,14 +1715,6 @@ def make_url_map(app):
addon_views.create_waterbutler_log,
json_renderer,
),
Rule(
[
'/registration/<pid>/callbacks/',
],
'put',
project_views.register.registration_callbacks,
json_renderer,
),
Rule(
'/settings/addons/',
'post',
Expand Down
Loading