diff --git a/admin/preprint_providers/forms.py b/admin/preprint_providers/forms.py index 06d27052098..1393aae41ef 100644 --- a/admin/preprint_providers/forms.py +++ b/admin/preprint_providers/forms.py @@ -112,11 +112,12 @@ class PreprintProviderRegisterModeratorOrAdminForm(forms.Form): """ A form that finds an existing OSF User, and grants permissions to that user so that they can use the admin app""" - def __init__(self, *args, **kwargs): - provider_id = kwargs.pop('provider_id') + def __init__(self, *args, provider_groups=None, **kwargs): super().__init__(*args, **kwargs) + + provider_groups = provider_groups or Group.objects.none() self.fields['group_perms'] = forms.ModelMultipleChoiceField( - queryset=Group.objects.filter(name__startswith=f'reviews_preprint_{provider_id}'), + queryset=provider_groups, required=False, widget=forms.CheckboxSelectMultiple ) diff --git a/admin/preprint_providers/views.py b/admin/preprint_providers/views.py index e98aed1ecfa..4c7439f4554 100644 --- a/admin/preprint_providers/views.py +++ b/admin/preprint_providers/views.py @@ -12,6 +12,7 @@ from django.contrib.auth.mixins import PermissionRequiredMixin from django.forms.models import model_to_dict from django.shortcuts import redirect, render +from django.utils.functional import cached_property from admin.base import settings from admin.base.forms import ImportFileForm @@ -459,14 +460,18 @@ class PreprintProviderRegisterModeratorOrAdmin(PermissionRequiredMixin, FormView template_name = 'preprint_providers/register_moderator_admin.html' form_class = PreprintProviderRegisterModeratorOrAdminForm + @cached_property + def target_provider(self): + return PreprintProvider.objects.get(id=self.kwargs['preprint_provider_id']) + def get_form_kwargs(self): kwargs = super().get_form_kwargs() - kwargs['provider_id'] = self.kwargs['preprint_provider_id'] + kwargs['provider_groups'] = self.target_provider.group_objects return kwargs def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - context['provider_name'] = PreprintProvider.objects.get(id=self.kwargs['preprint_provider_id']).name + context['provider_name'] = self.target_provider.name return context def form_valid(self, form): @@ -477,13 +482,7 @@ def form_valid(self, form): raise Http404(f'OSF user with id "{user_id}" not found. Please double check.') for group in form.cleaned_data.get('group_perms'): - osf_user.groups.add(group) - split = group.name.split('_') - group_type = split[0] - if group_type == 'reviews': - provider_id = split[2] - provider = PreprintProvider.objects.get(id=provider_id) - provider.notification_subscriptions.get(event_name='new_pending_submissions').add_user_to_subscription(osf_user, 'email_transactional') + self.target_provider.add_to_group(osf_user, group) osf_user.save() messages.success(self.request, f'Permissions update successful for OSF User {osf_user.username}!') diff --git a/osf/models/mixins.py b/osf/models/mixins.py index ac8d5a094d8..b7fe97b7ece 100644 --- a/osf/models/mixins.py +++ b/osf/models/mixins.py @@ -1061,12 +1061,17 @@ def get_request_state_counts(self): return counts def add_to_group(self, user, group): + if isinstance(group, Group): + group.user_set.add(user) + elif isinstance(group, str): + self.get_group(group).user_set.add(user) + else: + raise TypeError(f"Unsupported group type: {type(group)}") + # Add default notification subscription for subscription in self.DEFAULT_SUBSCRIPTIONS: self.add_user_to_subscription(user, f'{self._id}_{subscription}') - return self.get_group(group).user_set.add(user) - def remove_from_group(self, user, group, unsubscribe=True): _group = self.get_group(group) if group == ADMIN: