diff --git a/rest_registration/api/views/register.py b/rest_registration/api/views/register.py index fa18178..db68169 100644 --- a/rest_registration/api/views/register.py +++ b/rest_registration/api/views/register.py @@ -13,7 +13,11 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id, get_user_setting +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_setting, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -30,8 +34,8 @@ def get_valid_period(self): def _calculate_salt(self, data): if registration_settings.REGISTER_VERIFICATION_ONE_TIME_USE: - user_id = data['user_id'] - user = get_user_by_id(user_id, require_verified=False) + user = get_user_by_verification_id( + data['user_id'], require_verified=False) # Use current user verification flag as a part of the salt. # If the verification flag gets changed, then assume that # the change was caused by previous verification and the signature @@ -77,7 +81,7 @@ def register(request): if registration_settings.REGISTER_VERIFICATION_ENABLED: signer = RegisterSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), }, request=request) template_config = ( registration_settings.REGISTER_VERIFICATION_EMAIL_TEMPLATES) @@ -117,7 +121,7 @@ def process_verify_registration_data(input_data): verify_signer_or_bad_request(signer) verification_flag_field = get_user_setting('VERIFICATION_FLAG_FIELD') - user = get_user_by_id(data['user_id'], require_verified=False) + user = get_user_by_verification_id(data['user_id'], require_verified=False) setattr(user, verification_flag_field, True) user.save() diff --git a/rest_registration/api/views/register_email.py b/rest_registration/api/views/register_email.py index fe9ce74..6d55cb0 100644 --- a/rest_registration/api/views/register_email.py +++ b/rest_registration/api/views/register_email.py @@ -10,7 +10,11 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id, get_user_setting +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_setting, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -46,7 +50,7 @@ def register_email(request): registration_settings.REGISTER_EMAIL_VERIFICATION_EMAIL_TEMPLATES) if registration_settings.REGISTER_EMAIL_VERIFICATION_ENABLED: signer = RegisterEmailSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), 'email': email, }, request=request) send_verification_notification( @@ -88,6 +92,6 @@ def process_verify_email_data(input_data): verify_signer_or_bad_request(signer) email_field = get_user_setting('EMAIL_FIELD') - user = get_user_by_id(data['user_id']) + user = get_user_by_verification_id(data['user_id']) setattr(user, email_field, data['email']) user.save() diff --git a/rest_registration/api/views/reset_password.py b/rest_registration/api/views/reset_password.py index 04ff60a..7ce8279 100644 --- a/rest_registration/api/views/reset_password.py +++ b/rest_registration/api/views/reset_password.py @@ -13,7 +13,10 @@ from rest_registration.notifications import send_verification_notification from rest_registration.settings import registration_settings from rest_registration.utils.responses import get_ok_response -from rest_registration.utils.users import get_user_by_id +from rest_registration.utils.users import ( + get_user_by_verification_id, + get_user_verification_id +) from rest_registration.utils.verification import verify_signer_or_bad_request from rest_registration.verification import URLParamsSigner @@ -30,8 +33,8 @@ def get_valid_period(self): def _calculate_salt(self, data): if registration_settings.RESET_PASSWORD_VERIFICATION_ONE_TIME_USE: - user_id = data['user_id'] - user = get_user_by_id(user_id) + user = get_user_by_verification_id( + data['user_id'], require_verified=False) # Use current user password hash as a part of the salt. # If the password gets changed, then assume that the change # was caused by previous password reset and the signature @@ -61,7 +64,7 @@ def send_reset_password_link(request): if not user: raise UserNotFound() signer = ResetPasswordSigner({ - 'user_id': user.pk, + 'user_id': get_user_verification_id(user), }, request=request) template_config = ( @@ -100,7 +103,7 @@ def process_reset_password_data(input_data): signer = ResetPasswordSigner(data) verify_signer_or_bad_request(signer) - user = get_user_by_id(data['user_id'], require_verified=False) + user = get_user_by_verification_id(data['user_id'], require_verified=False) try: validate_password(password, user=user) except ValidationError as exc: diff --git a/rest_registration/settings_fields.py b/rest_registration/settings_fields.py index 1aad5fb..3dca958 100644 --- a/rest_registration/settings_fields.py +++ b/rest_registration/settings_fields.py @@ -43,6 +43,17 @@ def __new__(cls, name, *, default=None, help=None, import_string=False): 'USER_EMAIL_FIELD', default='email', ), + Field( + 'USER_VERIFICATION_ID_FIELD', + default='pk', + help=dedent("""\ + Field used in verification, as part of signed data. + + The given field should uniquely identify the user. This means that + using any user field which could change over time + (``email``, ``username``) is NOT recommended. + """), + ), Field( 'USER_VERIFICATION_FLAG_FIELD', default='is_active', diff --git a/rest_registration/utils/users.py b/rest_registration/utils/users.py index 3e8bdd9..7cebc8d 100644 --- a/rest_registration/utils/users.py +++ b/rest_registration/utils/users.py @@ -56,10 +56,18 @@ def authenticate_by_login_and_password_or_none(login, password): return user -def get_user_by_id(user_id, default=_RAISE_EXCEPTION, require_verified=True): +def get_user_verification_id(user): + verification_id_field = get_user_setting('VERIFICATION_ID_FIELD') + return getattr(user, verification_id_field) + + +def get_user_by_verification_id( + user_verification_id, default=_RAISE_EXCEPTION, require_verified=True): + verification_id_field = get_user_setting('VERIFICATION_ID_FIELD') return get_user_by_lookup_dict({ - 'pk': user_id, - }, require_verified=require_verified) + verification_id_field: user_verification_id}, + default=default, + require_verified=require_verified) def get_user_by_lookup_dict(