From 624f1e3d920c57c1c89ecb1a89c0fa810a08ec67 Mon Sep 17 00:00:00 2001 From: Ostap Zherebetskyi Date: Thu, 3 Apr 2025 14:04:40 +0300 Subject: [PATCH 1/4] Add registration callback endpoint and tests --- addons/base/views.py | 6 +- api/registrations/urls.py | 1 + api/registrations/views.py | 53 +++++++++++- .../views/test_regisatration_callbacks.py | 82 +++++++++++++++++++ osf/models/node.py | 7 ++ 5 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 api_tests/registrations/views/test_regisatration_callbacks.py diff --git a/addons/base/views.py b/addons/base/views.py index 6f22c71f3e3..a6c90860b98 100644 --- a/addons/base/views.py +++ b/addons/base/views.py @@ -431,11 +431,7 @@ def _enqueue_metrics(file_version, file_node, action, auth, from_mfr=False): def _construct_payload(auth, resource, credentials, waterbutler_settings): if isinstance(resource, Registration): - callback_url = resource.api_url_for( - 'registration_callbacks', - _absolute=True, - _internal=True - ) + callback_url = resource.callbacks_url else: callback_url = resource.api_url_for( 'create_waterbutler_log', diff --git a/api/registrations/urls.py b/api/registrations/urls.py index 66e5175b05b..232be823bb9 100644 --- a/api/registrations/urls.py +++ b/api/registrations/urls.py @@ -13,6 +13,7 @@ re_path(r'^(?P\w+)/$', views.RegistrationDetail.as_view(), name=views.RegistrationDetail.view_name), re_path(r'^(?P\w+)/bibliographic_contributors/$', views.RegistrationBibliographicContributorsList.as_view(), name=views.RegistrationBibliographicContributorsList.view_name), re_path(r'^(?P\w+)/cedar_metadata_records/$', views.RegistrationCedarMetadataRecordsList.as_view(), name=views.RegistrationCedarMetadataRecordsList.view_name), + re_path(r'^(?P\w+)/callbacks/$', views.RegistrationCallbackView.as_view(), name=views.RegistrationCallbackView.view_name), re_path(r'^(?P\w+)/children/$', views.RegistrationChildrenList.as_view(), name=views.RegistrationChildrenList.view_name), re_path(r'^(?P\w+)/comments/$', views.RegistrationCommentsList.as_view(), name=views.RegistrationCommentsList.view_name), re_path(r'^(?P\w+)/contributors/$', views.RegistrationContributorsList.as_view(), name=views.RegistrationContributorsList.view_name), diff --git a/api/registrations/views.py b/api/registrations/views.py index 29305d23ce2..47ca05438f2 100644 --- a/api/registrations/views.py +++ b/api/registrations/views.py @@ -1,7 +1,13 @@ -from rest_framework import generics, mixins, permissions as drf_permissions +from rest_framework import generics, mixins, permissions as drf_permissions, status from rest_framework.exceptions import ValidationError, NotFound, PermissionDenied +from rest_framework.response import Response +from framework.exceptions import HTTPError from framework.auth.oauth_scopes import CoreScopes +from addons.base.views import DOWNLOAD_ACTIONS +from website.archiver import signals, ARCHIVER_NETWORK_ERROR, ARCHIVER_SUCCESS, ARCHIVER_FAILURE +from website.project import signals as project_signals + from osf.models import Registration, OSFUser, RegistrationProvider, OutcomeArtifact, CedarMetadataRecord from osf.utils.permissions import WRITE_NODE from osf.utils.workflows import ApprovalStates @@ -28,6 +34,7 @@ JSONAPIMultipleRelationshipsParser, JSONAPIRelationshipParserForRegularJSON, JSONAPIMultipleRelationshipsParserForRegularJSON, + HMACSignedParser, ) from api.base.utils import ( get_user_auth, @@ -1038,3 +1045,47 @@ def get_default_queryset(self): def get_queryset(self): return self.get_queryset_from_request() + + +class RegistrationCallbackView(JSONAPIBaseView, generics.UpdateAPIView, RegistrationMixin): + permission_classes = [drf_permissions.AllowAny] + + view_category = 'registrations' + view_name = 'registration-callbacks' + + parser_classes = [HMACSignedParser] + + def update(self, request, *args, **kwargs): + registration = self.get_node() + + try: + payload = request.data + if payload.get('action', None) in DOWNLOAD_ACTIONS: + return Response({'status': 'success'}, status=status.HTTP_200_OK) + errors = payload.get('errors') + src_provider = payload['source']['provider'] + if errors: + registration.archive_job.update_target( + src_provider, + ARCHIVER_FAILURE, + errors=errors, + ) + else: + # Dataverse requires two seperate targets, one + # for draft files and one for published files + if src_provider == 'dataverse': + src_provider += '-' + (payload['destination']['name'].split(' ')[-1].lstrip('(').rstrip(')').strip()) + registration.archive_job.update_target( + src_provider, + ARCHIVER_SUCCESS, + ) + project_signals.archive_callback.send(registration) + return Response(status=status.HTTP_200_OK) + except HTTPError as e: + registration.archive_status = ARCHIVER_NETWORK_ERROR + registration.save() + signals.archive_fail.send( + registration, + errors=[str(e)] + ) + return Response(status=status.HTTP_200_OK) diff --git a/api_tests/registrations/views/test_regisatration_callbacks.py b/api_tests/registrations/views/test_regisatration_callbacks.py new file mode 100644 index 00000000000..d559fbf14b7 --- /dev/null +++ b/api_tests/registrations/views/test_regisatration_callbacks.py @@ -0,0 +1,82 @@ +import copy +import time +import pytest + +from api.base.settings.defaults import API_BASE +from osf_tests.factories import RegistrationFactory +from framework.auth import signing + + +@pytest.mark.django_db +class TestRegistrationCallbacks: + + @pytest.fixture() + def registration(self): + registration = RegistrationFactory() + return registration + + @pytest.fixture() + def url(self, registration): + return f'/{API_BASE}registrations/{registration._id}/callbacks/' + + @pytest.fixture() + def payload(self): + return { + "action": "copy", + "destination": { + "name": "Archive of OSF Storage", + }, + "errors": None, + "source": { + "provider": "osfstorage", + }, + "time": time.time() + 1000 + } + + def sign_payload(self, payload): + message, signature = signing.default_signer.sign_payload(payload) + return { + 'payload': message, + 'signature': signature, + } + + def test_registration_callback(self, app, payload, url): + data = self.sign_payload(payload) + res = app.put_json(url, data) + assert res.status_code == 200 + + def test_signature_expired(self, app, payload, url): + payload['time'] = time.time() - 100 + data = self.sign_payload(payload) + res = app.put_json(url, data, expect_errors=True) + assert res.status_code == 400 + assert res.json['errors'][0]['detail'] == 'Signature has expired' + + def test_bad_signature(self, app, payload, url): + data = self.sign_payload(payload) + data['signature'] = '1234' + res = app.put_json(url, data, expect_errors=True) + assert res.status_code == 401 + assert res.json['errors'][0]['detail'] == 'Authentication credentials were not provided.' + + def test_invalid_payload(self, app, payload, url): + payload1 = copy.deepcopy(payload) + del payload1['time'] + data = self.sign_payload(payload1) + res = app.put_json(url, data, expect_errors=True) + assert res.status_code == 400 + assert res.json['errors'][0]['detail'] == 'Invalid Payload' + + payload2 = copy.deepcopy(payload) + data = self.sign_payload(payload2) + del data['signature'] + res = app.put_json(url, data, expect_errors=True) + assert res.status_code == 400 + assert res.json['errors'][0]['detail'] == 'Invalid Payload' + + payload3 = copy.deepcopy(payload) + data = self.sign_payload(payload3) + del data['payload'] + res = app.put_json(url, data, expect_errors=True) + assert res.status_code == 400 + assert res.json['errors'][0]['detail'] == 'Invalid Payload' diff --git a/osf/models/node.py b/osf/models/node.py index d06af182e47..cf182a7c9ed 100644 --- a/osf/models/node.py +++ b/osf/models/node.py @@ -615,6 +615,10 @@ def institutions_url(self): def institutions_relationship_url(self): return self.absolute_api_v2_url + 'relationships/institutions/' + @property + def callbacks_url(self): + return self.absolute_api_v2_url + 'callbacks/' + # For Comment API compatibility @property def target_type(self): @@ -664,6 +668,9 @@ def web_url_for(self, view_name, _absolute=False, _guid=False, *args, **kwargs): def api_url_for(self, view_name, _absolute=False, *args, **kwargs): return api_url_for(view_name, pid=self._primary_key, _absolute=_absolute, *args, **kwargs) + def api_v2_url_for(self, path_str, params=None, **kwargs): + return api_url_for(path_str, params=params, **kwargs) + @property def project_or_component(self): # The distinction is drawn based on whether something has a parent node, rather than by category From 940bc3c5b9936c0f7fcadc48243dbf74c5e0c804 Mon Sep 17 00:00:00 2001 From: Ostap Zherebetskyi Date: Thu, 3 Apr 2025 14:14:28 +0300 Subject: [PATCH 2/4] add-trailing-comma double-quote-string-fixer --- api/registrations/views.py | 2 +- .../views/test_regisatration_callbacks.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/registrations/views.py b/api/registrations/views.py index 47ca05438f2..e540acd3c31 100644 --- a/api/registrations/views.py +++ b/api/registrations/views.py @@ -1086,6 +1086,6 @@ def update(self, request, *args, **kwargs): registration.save() signals.archive_fail.send( registration, - errors=[str(e)] + errors=[str(e)], ) return Response(status=status.HTTP_200_OK) diff --git a/api_tests/registrations/views/test_regisatration_callbacks.py b/api_tests/registrations/views/test_regisatration_callbacks.py index d559fbf14b7..35d65d013b6 100644 --- a/api_tests/registrations/views/test_regisatration_callbacks.py +++ b/api_tests/registrations/views/test_regisatration_callbacks.py @@ -22,15 +22,15 @@ def url(self, registration): @pytest.fixture() def payload(self): return { - "action": "copy", - "destination": { - "name": "Archive of OSF Storage", + 'action': 'copy', + 'destination': { + 'name': 'Archive of OSF Storage', }, - "errors": None, - "source": { - "provider": "osfstorage", + 'errors': None, + 'source': { + 'provider': 'osfstorage', }, - "time": time.time() + 1000 + 'time': time.time() + 1000 } def sign_payload(self, payload): From 178badb648cde2ba4b2f4f45b149db65f0f26b57 Mon Sep 17 00:00:00 2001 From: Ostap Zherebetskyi Date: Thu, 3 Apr 2025 14:39:52 +0300 Subject: [PATCH 3/4] fix test_view_classes_have_minimal_set_of_permissions_classes --- api_tests/base/test_views.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api_tests/base/test_views.py b/api_tests/base/test_views.py index 2db3a4b65b2..b09df8d753c 100644 --- a/api_tests/base/test_views.py +++ b/api_tests/base/test_views.py @@ -18,6 +18,7 @@ MetricsOpenapiView, ) from api.users.views import ClaimUser, ResetPassword, ExternalLoginConfirmEmailView, ExternalLogin +from api.registrations.views import RegistrationCallbackView from api.wb.views import MoveFileMetadataView, CopyFileMetadataView from rest_framework.permissions import IsAuthenticatedOrReadOnly, IsAuthenticated from api.base.permissions import TokenHasScope @@ -63,6 +64,7 @@ def setUp(self): ResetPassword, ExternalLoginConfirmEmailView, ExternalLogin, + RegistrationCallbackView, ] def test_root_returns_200(self): From 054ce019631cb89ab4312514d0abe03f19e2746b Mon Sep 17 00:00:00 2001 From: Ostap Zherebetskyi Date: Thu, 3 Apr 2025 16:47:22 +0300 Subject: [PATCH 4/4] remove old Flask registration_callback code and related decorators --- osf_tests/test_archiver.py | 17 ---------------- website/archiver/decorators.py | 25 ----------------------- website/project/decorators.py | 19 ----------------- website/project/views/register.py | 34 +------------------------------ website/routes.py | 8 -------- 5 files changed, 1 insertion(+), 102 deletions(-) delete mode 100644 website/archiver/decorators.py diff --git a/osf_tests/test_archiver.py b/osf_tests/test_archiver.py index 3855d169acb..59c178b839d 100644 --- a/osf_tests/test_archiver.py +++ b/osf_tests/test_archiver.py @@ -22,7 +22,6 @@ from website.app import * # noqa: F403 from website.archiver import listeners from website.archiver.tasks import * # noqa: F403 -from website.archiver.decorators import fail_archive_on_error from osf.models import Guid, RegistrationSchema, Registration from osf.models.archive import ArchiveTarget, ArchiveJob @@ -1111,22 +1110,6 @@ def test_find_failed_registrations(self): assert pk not in failed -class TestArchiverDecorators(ArchiverTestCase): - - @mock.patch('website.archiver.signals.archive_fail.send') - def test_fail_archive_on_error(self, mock_fail): - e = HTTPError(418) - - def error(*args, **kwargs): - raise e - - func = fail_archive_on_error(error) - func(node=self.dst) - mock_fail.assert_called_with( - self.dst, - errors=[str(e)] - ) - class TestArchiverBehavior(OsfTestCase): @mock.patch('osf.models.AbstractNode.update_search') diff --git a/website/archiver/decorators.py b/website/archiver/decorators.py deleted file mode 100644 index 0d6f46bfb37..00000000000 --- a/website/archiver/decorators.py +++ /dev/null @@ -1,25 +0,0 @@ -import functools - -from framework.exceptions import HTTPError - -from website.project.decorators import _inject_nodes -from website.archiver import ARCHIVER_NETWORK_ERROR -from website.archiver import signals - - -def fail_archive_on_error(func): - - @functools.wraps(func) - def wrapped(*args, **kwargs): - try: - return func(*args, **kwargs) - except HTTPError as e: - _inject_nodes(kwargs) - registration = kwargs['node'] - registration.archive_status = ARCHIVER_NETWORK_ERROR - registration.save() - signals.archive_fail.send( - registration, - errors=[str(e)] - ) - return wrapped diff --git a/website/project/decorators.py b/website/project/decorators.py index 0e165146250..2d60be5359b 100644 --- a/website/project/decorators.py +++ b/website/project/decorators.py @@ -173,25 +173,6 @@ def wrapped(*args, **kwargs): return wrapped -def must_be_registration(func): - - @functools.wraps(func) - def wrapped(*args, **kwargs): - _inject_nodes(kwargs) - node = kwargs['node'] - - if not node.is_registration: - raise HTTPError( - http_status.HTTP_400_BAD_REQUEST, - data={ - 'message_short': 'Registered Nodes only', - 'message_long': 'This view is restricted to registered Nodes only', - } - ) - return func(*args, **kwargs) - - return wrapped - def check_can_download_preprint_file(user, node): """View helper that returns whether a given user can download unpublished preprint files. diff --git a/website/project/views/register.py b/website/project/views/register.py index 11a5da7f53c..265fda1edea 100644 --- a/website/project/views/register.py +++ b/website/project/views/register.py @@ -7,17 +7,12 @@ from framework.exceptions import HTTPError from framework.flask import redirect # VOL-aware redirect -from framework.auth.decorators import must_be_signed - -from website.archiver import ARCHIVER_SUCCESS, ARCHIVER_FAILURE - -from addons.base.views import DOWNLOAD_ACTIONS from website import settings from osf.exceptions import NodeStateError from website.project.decorators import ( must_be_valid_project, must_be_contributor_or_public, must_have_permission, must_be_contributor_and_not_group_member, - must_not_be_registration, must_be_registration, + must_not_be_registration, must_not_be_retracted_registration ) from osf import features @@ -26,12 +21,10 @@ from osf.utils.permissions import ADMIN from website import language from website.ember_osf_web.decorators import ember_flag_is_active -from website.project import signals as project_signals from website.project.metadata.schemas import _id_to_name from website import util from website.project.metadata.utils import serialize_meta_schema from website.project.model import has_anonymous_link -from website.archiver.decorators import fail_archive_on_error from .node import _view_project from api.waffle.utils import flag_is_active @@ -228,28 +221,3 @@ def get_referent_by_identifier(category, value): if identifier.referent.url: return redirect(identifier.referent.url) raise HTTPError(http_status.HTTP_404_NOT_FOUND) - -@fail_archive_on_error -@must_be_signed -@must_be_registration -def registration_callbacks(node, payload, *args, **kwargs): - if payload.get('action', None) in DOWNLOAD_ACTIONS: - return {'status': 'success'} - errors = payload.get('errors') - src_provider = payload['source']['provider'] - if errors: - node.archive_job.update_target( - src_provider, - ARCHIVER_FAILURE, - errors=errors, - ) - else: - # Dataverse requires two seperate targets, one - # for draft files and one for published files - if src_provider == 'dataverse': - src_provider += '-' + (payload['destination']['name'].split(' ')[-1].lstrip('(').rstrip(')').strip()) - node.archive_job.update_target( - src_provider, - ARCHIVER_SUCCESS, - ) - project_signals.archive_callback.send(node) diff --git a/website/routes.py b/website/routes.py index 227d68302e3..8e2ab328a72 100644 --- a/website/routes.py +++ b/website/routes.py @@ -1715,14 +1715,6 @@ def make_url_map(app): addon_views.create_waterbutler_log, json_renderer, ), - Rule( - [ - '/registration//callbacks/', - ], - 'put', - project_views.register.registration_callbacks, - json_renderer, - ), Rule( '/settings/addons/', 'post',