diff --git a/framework/auth/core.py b/framework/auth/core.py index 16bd3224f9c9..a62feb30fe38 100644 --- a/framework/auth/core.py +++ b/framework/auth/core.py @@ -115,9 +115,14 @@ def validate_user_with_verification_key(username=None, verification_key=None): return None user_obj = get_user(username=username) if user_obj: - if user_obj.verification_key_v2['token'] == verification_key: - if user_obj.verification_key_v2['expires'] > dt.datetime.utcnow(): - return user_obj + try: + if user_obj.verification_key_v2: + if user_obj.verification_key_v2['token'] == verification_key: + if user_obj.verification_key_v2['expires'] > dt.datetime.utcnow(): + return user_obj + except AttributeError: + # if user does not have verification_key_v2 + return None return None diff --git a/tests/factories.py b/tests/factories.py index 153c439f65d6..85e8079a47c7 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -102,6 +102,7 @@ class Meta: merged_by = None email_verifications = {} verification_key = None + verification_key_v2 = {} @post_generation def set_names(self, create, extracted): diff --git a/tests/webtest_tests.py b/tests/webtest_tests.py index 3bca8b9654b3..3ce4fb5f0bcc 100644 --- a/tests/webtest_tests.py +++ b/tests/webtest_tests.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """Functional tests using WebTest.""" +import datetime as dt import httplib as http import logging import unittest @@ -14,7 +15,7 @@ from framework.auth import cas from framework.auth import exceptions as auth_exc from framework.auth.core import Auth -from framework.auth.core import generate_verification_key +from framework.auth.core import generate_verification_key, generate_verification_key_v2 from tests.base import OsfTestCase from tests.base import fake from tests.factories import (UserFactory, AuthUserFactory, ProjectFactory, WatchConfigFactory, NodeFactory, @@ -1009,23 +1010,47 @@ def setUp(self): super(TestResetPassword, self).setUp() self.user = AuthUserFactory() self.another_user = AuthUserFactory() - self.osf_key = generate_verification_key() - self.user.verification_key = self.osf_key + self.osf_key_v2 = generate_verification_key_v2() + self.user.verification_key_v2 = self.osf_key_v2 + self.user.verification_key = None self.user.save() - self.cas_key = None - self.get_url = web_url_for('reset_password_get', verification_key=self.osf_key) - self.get_url_invalid_key = web_url_for('reset_password_get', verification_key=generate_verification_key()) + self.get_url = web_url_for( + 'reset_password_get', + username=self.user.username, + verification_key=self.osf_key_v2['token'] + ) + self.get_url_invalid_key = web_url_for( + 'reset_password_get', + username=self.user.username, + verification_key=generate_verification_key() + ) + self.get_url_invalid_user = web_url_for( + 'reset_password_get', + username=self.another_user.username, + verification_key=self.osf_key_v2['token'] + ) # load reset password page if verification_key is valid def test_reset_password_view_returns_200(self): res = self.app.get(self.get_url) assert_equal(res.status_code, 200) - # raise http 400 error if verification_key(OSF) is invalid + # raise http 400 error if: + # verification_key_v2 is invalid, or + # verification_key_v2 has expired, or + # user is invalid def test_reset_password_view_raises_400(self): res = self.app.get(self.get_url_invalid_key, expect_errors=True) assert_equal(res.status_code, 400) + res = self.app.get(self.get_url_invalid_user, expect_errors=True) + assert_equal(res.status_code, 400) + + self.user.verification_key_v2['expires'] = dt.datetime.utcnow() + self.user.save() + res = self.app.get(self.get_url, expect_errors=True) + assert_equal(res.status_code, 400) + # successfully reset password if osf verification_key(OSF) is valid and form is valid @mock.patch('framework.auth.cas.CasClient.service_validate') def test_can_reset_password_if_form_success(self, mock_service_validate): @@ -1036,15 +1061,16 @@ def test_can_reset_password_if_form_success(self, mock_service_validate): form['password2'] = 'newpassword' res = form.submit() - # check request URL is /resetpassword with verification_key(OSF) + # check request URL is /resetpassword with username and new verification_key_v2 token request_url_path = res.request.path assert_in('resetpassword', request_url_path) - assert_in(self.user.verification_key, request_url_path) + assert_in(self.user.username, request_url_path) + assert_not_in(self.user.verification_key_v2['token'], request_url_path) - # check verification_key(OSF) is destroyed and a new verification_key(CAS) is in place + # check verification_key_v2 for OSF is destroyed and verification_key for CAS is in place self.user.reload() - self.cas_key = self.user.verification_key - assert_not_equal(self.cas_key, self.osf_key) + assert_equal(self.user.verification_key_v2, {}) + assert_not_equal(self.user.verification_key, None) # check redirection to CAS login with username and the new verification_key(CAS) assert_equal(res.status_code, 302) @@ -1057,7 +1083,7 @@ def test_can_reset_password_if_form_success(self, mock_service_validate): self.user.reload() assert_true(self.user.check_password('newpassword')) - # check if verification_key(CAS) is destroyed + # check if verification_key is destroyed after service validation mock_service_validate.return_value = cas.CasResponse( authenticated=True, user=self.user._primary_key, @@ -1065,8 +1091,8 @@ def test_can_reset_password_if_form_success(self, mock_service_validate): ) ticket = fake.md5() service_url = 'http://accounts.osf.io/?ticket=' + ticket - resp = cas.make_response_from_ticket(ticket, service_url) - assert_not_equal(self.user.verification_key, self.cas_key) + cas.make_response_from_ticket(ticket, service_url) + assert_equal(self.user.verification_key, None) # logged-in user should be automatically logged out upon before reset password def test_reset_password_logs_out_user(self):