diff --git a/rest_registration/api/views/register.py b/rest_registration/api/views/register.py index fa18178..db68169 100644 --- a/rest_registration/api/views/register.py +++ b/rest_registration/api/views/register.py @@ -13,7 +13,11 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id, get_user_setting +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_setting, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -30,8 +34,8 @@ def get_valid_period(self): def _calculate_salt(self, data): if registration_settings.REGISTER_VERIFICATION_ONE_TIME_USE: - user_id = data['user_id'] - user = get_user_by_id(user_id, require_verified=False) + user = get_user_by_verification_id( + data['user_id'], require_verified=False) # Use current user verification flag as a part of the salt. # If the verification flag gets changed, then assume that # the change was caused by previous verification and the signature @@ -77,7 +81,7 @@ def register(request): if registration_settings.REGISTER_VERIFICATION_ENABLED: signer = RegisterSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), }, request=request) template_config = ( registration_settings.REGISTER_VERIFICATION_EMAIL_TEMPLATES) @@ -117,7 +121,7 @@ def process_verify_registration_data(input_data): verify_signer_or_bad_request(signer) verification_flag_field = get_user_setting('VERIFICATION_FLAG_FIELD') - user = get_user_by_id(data['user_id'], require_verified=False) + user = get_user_by_verification_id(data['user_id'], require_verified=False) setattr(user, verification_flag_field, True) user.save() diff --git a/rest_registration/api/views/register_email.py b/rest_registration/api/views/register_email.py index fe9ce74..6d55cb0 100644 --- a/rest_registration/api/views/register_email.py +++ b/rest_registration/api/views/register_email.py @@ -10,7 +10,11 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id, get_user_setting +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_setting, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -46,7 +50,7 @@ def register_email(request): registration_settings.REGISTER_EMAIL_VERIFICATION_EMAIL_TEMPLATES) if registration_settings.REGISTER_EMAIL_VERIFICATION_ENABLED: signer = RegisterEmailSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), 'email': email, }, request=request) send_verification_notification( @@ -88,6 +92,6 @@ def process_verify_email_data(input_data): verify_signer_or_bad_request(signer) email_field = get_user_setting('EMAIL_FIELD') - user = get_user_by_id(data['user_id']) + user = get_user_by_verification_id(data['user_id']) setattr(user, email_field, data['email']) user.save() diff --git a/rest_registration/api/views/reset_password.py b/rest_registration/api/views/reset_password.py index 12c19f3..7ce8279 100644 --- a/rest_registration/api/views/reset_password.py +++ b/rest_registration/api/views/reset_password.py @@ -13,7 +13,10 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -30,8 +33,8 @@ def get_valid_period(self): def _calculate_salt(self, data): if registration_settings.RESET_PASSWORD_VERIFICATION_ONE_TIME_USE: - user_id = data['user_id'] - user = get_user_by_id(user_id, require_verified=False) + user = get_user_by_verification_id( + data['user_id'], require_verified=False) # Use current user password hash as a part of the salt. # If the password gets changed, then assume that the change # was caused by previous password reset and the signature @@ -61,7 +64,7 @@ def send_reset_password_link(request): if not user: raise UserNotFound() signer = ResetPasswordSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), }, request=request) template_config = ( @@ -100,7 +103,7 @@ def process_reset_password_data(input_data): signer = ResetPasswordSigner(data) verify_signer_or_bad_request(signer) - user = get_user_by_id(data['user_id'], require_verified=False) + user = get_user_by_verification_id(data['user_id'], require_verified=False) try: validate_password(password, user=user) except ValidationError as exc: diff --git a/rest_registration/settings_fields.py b/rest_registration/settings_fields.py index 1aad5fb..3dca958 100644 --- a/rest_registration/settings_fields.py +++ b/rest_registration/settings_fields.py @@ -43,6 +43,17 @@ def __new__(cls, name, *, default=None, help=None, import_string=False): 'USER_EMAIL_FIELD', default='email', ), + Field( + 'USER_VERIFICATION_ID_FIELD', + default='pk', + help=dedent("""\ + Field used in verification, as part of signed data. + + The given field should uniquely identify the user. This means that + using any user field which could change over time + (``email``, ``username``) is NOT recommended. + """), + ), Field( 'USER_VERIFICATION_FLAG_FIELD', default='is_active', diff --git a/rest_registration/utils/users.py b/rest_registration/utils/users.py index 3e8bdd9..7cebc8d 100644 --- a/rest_registration/utils/users.py +++ b/rest_registration/utils/users.py @@ -56,10 +56,18 @@ def authenticate_by_login_and_password_or_none(login, password): return user -def get_user_by_id(user_id, default=_RAISE_EXCEPTION, require_verified=True): +def get_user_verification_id(user): + verification_id_field = get_user_setting('VERIFICATION_ID_FIELD') + return getattr(user, verification_id_field) + + +def get_user_by_verification_id( + user_verification_id, default=_RAISE_EXCEPTION, require_verified=True): + verification_id_field = get_user_setting('VERIFICATION_ID_FIELD') return get_user_by_lookup_dict({ - 'pk': user_id, - }, require_verified=require_verified) + verification_id_field: user_verification_id}, + default=default, + require_verified=require_verified) def get_user_by_lookup_dict( diff --git a/tests/api/test_register.py b/tests/api/test_register.py index 90e6919..5317828 100644 --- a/tests/api/test_register.py +++ b/tests/api/test_register.py @@ -143,6 +143,45 @@ def test_register_ok(self): signer = RegisterSigner(verification_data) signer.verify() + @override_settings( + REST_REGISTRATION=shallow_merge_dicts( + REST_REGISTRATION_WITH_VERIFICATION, { + 'USER_VERIFICATION_ID_FIELD': 'username', + }, + ), + ) + def test_register_with_username_as_verification_id_ok(self): + # Using username is not recommended if it can change for a given user. + data = self._get_register_user_data(password='testpassword') + request = self.create_post_request(data) + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: + response = self.view_func(request) + self.assert_valid_response(response, status.HTTP_201_CREATED) + user_id = response.data['id'] + # Check database state. + user = self.user_class.objects.get(id=user_id) + self.assertEqual(user.username, data['username']) + self.assertTrue(user.check_password(data['password'])) + self.assertFalse(user.is_active) + # Check verification e-mail. + sent_email = sent_emails[0] + self.assertEqual(sent_email.from_email, VERIFICATION_FROM_EMAIL) + self.assertListEqual(sent_email.to, [data['email']]) + url = self.assert_one_url_line_in_text(sent_email.body) + + verification_data = self.assert_valid_verification_url( + url, + expected_path=REGISTER_VERIFICATION_URL, + expected_fields={'signature', 'user_id', 'timestamp'}, + ) + user_verification_id = verification_data['user_id'] + self.assertEqual(user_verification_id, user.username) + url_sig_timestamp = int(verification_data['timestamp']) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) + signer = RegisterSigner(verification_data) + signer.verify() + @override_settings( REST_REGISTRATION=shallow_merge_dicts( REST_REGISTRATION_WITH_VERIFICATION, { @@ -351,8 +390,10 @@ def prepare_user(self): self.assertFalse(user.is_active) return user - def prepare_request(self, user, session=False): - signer = RegisterSigner({'user_id': user.pk}) + def prepare_request(self, user, session=False, data_to_sign=None): + if data_to_sign is None: + data_to_sign = {'user_id': user.pk} + signer = RegisterSigner(data_to_sign) data = signer.get_signed_data() request = self.create_post_request(data) if session: @@ -372,6 +413,22 @@ def test_verify_ok(self): user.refresh_from_db() self.assertTrue(user.is_active) + @override_settings( + REST_REGISTRATION=shallow_merge_dicts( + REST_REGISTRATION_WITH_VERIFICATION, { + 'USER_VERIFICATION_ID_FIELD': 'username', + }, + ), + ) + def test_verify_with_username_as_verification_id_ok(self): + user = self.prepare_user() + request = self.prepare_request( + user, data_to_sign={'user_id': user.username}) + response = self.view_func(request) + self.assert_valid_response(response, status.HTTP_200_OK) + user.refresh_from_db() + self.assertTrue(user.is_active) + @override_settings(REST_REGISTRATION=REST_REGISTRATION_WITH_VERIFICATION) def test_verify_ok_idempotent(self): user = self.prepare_user()