diff --git a/tests/api/test_register.py b/tests/api/test_register.py index 29ba3e3..90e6919 100644 --- a/tests/api/test_register.py +++ b/tests/api/test_register.py @@ -1,4 +1,3 @@ -import math import time from unittest import mock from unittest.mock import patch @@ -116,11 +115,9 @@ class RegisterViewTestCase(APIViewTestCase): def test_register_ok(self): data = self._get_register_user_data(password='testpassword') request = self.create_post_request(data) - time_before = math.floor(time.time()) - with self.assert_one_mail_sent() as sent_emails: + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: response = self.view_func(request) - time_after = math.ceil(time.time()) - self.assert_valid_response(response, status.HTTP_201_CREATED) + self.assert_valid_response(response, status.HTTP_201_CREATED) user_id = response.data['id'] # Check database state. user = self.user_class.objects.get(id=user_id) @@ -141,8 +138,8 @@ def test_register_ok(self): url_user_id = int(verification_data['user_id']) self.assertEqual(url_user_id, user_id) url_sig_timestamp = int(verification_data['timestamp']) - self.assertGreaterEqual(url_sig_timestamp, time_before) - self.assertLessEqual(url_sig_timestamp, time_after) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) signer = RegisterSigner(verification_data) signer.verify() @@ -156,11 +153,9 @@ def test_register_ok(self): def test_register_with_custom_verification_url_ok(self): data = self._get_register_user_data(password='testpassword') request = self.create_post_request(data) - time_before = math.floor(time.time()) - with self.assert_one_mail_sent() as sent_emails: + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: response = self.view_func(request) - time_after = math.ceil(time.time()) - self.assert_valid_response(response, status.HTTP_201_CREATED) + self.assert_valid_response(response, status.HTTP_201_CREATED) user_id = response.data['id'] # Check database state. user = self.user_class.objects.get(id=user_id) @@ -182,8 +177,8 @@ def test_register_with_custom_verification_url_ok(self): url_user_id = int(verification_data['user_id']) self.assertEqual(url_user_id, user_id) url_sig_timestamp = int(verification_data['timestamp']) - self.assertGreaterEqual(url_sig_timestamp, time_before) - self.assertLessEqual(url_sig_timestamp, time_after) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) signer = RegisterSigner(verification_data) signer.verify() @@ -193,11 +188,9 @@ def test_register_with_custom_verification_url_ok(self): def test_register_with_html_email_ok(self): data = self._get_register_user_data(password='testpassword') request = self.create_post_request(data) - time_before = math.floor(time.time()) - with self.assert_one_mail_sent() as sent_emails: + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: response = self.view_func(request) - time_after = math.ceil(time.time()) - self.assert_valid_response(response, status.HTTP_201_CREATED) + self.assert_valid_response(response, status.HTTP_201_CREATED) user_id = response.data['id'] # Check database state. user = self.user_class.objects.get(id=user_id) @@ -218,8 +211,8 @@ def test_register_with_html_email_ok(self): url_user_id = int(verification_data['user_id']) self.assertEqual(url_user_id, user_id) url_sig_timestamp = int(verification_data['timestamp']) - self.assertGreaterEqual(url_sig_timestamp, time_before) - self.assertLessEqual(url_sig_timestamp, time_after) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) signer = RegisterSigner(verification_data) signer.verify() @@ -234,11 +227,9 @@ def test_register_no_password_confirm_ok(self): data = self._get_register_user_data(password='testpassword') data.pop('password_confirm') request = self.create_post_request(data) - time_before = math.floor(time.time()) - with self.assert_one_mail_sent() as sent_emails: + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: response = self.view_func(request) self.assert_valid_response(response, status.HTTP_201_CREATED) - time_after = math.ceil(time.time()) user_id = response.data['id'] # Check database state. user = self.user_class.objects.get(id=user_id) @@ -259,8 +250,8 @@ def test_register_no_password_confirm_ok(self): url_user_id = int(verification_data['user_id']) self.assertEqual(url_user_id, user_id) url_sig_timestamp = int(verification_data['timestamp']) - self.assertGreaterEqual(url_sig_timestamp, time_before) - self.assertLessEqual(url_sig_timestamp, time_after) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) signer = RegisterSigner(verification_data) signer.verify() diff --git a/tests/api/test_register_email.py b/tests/api/test_register_email.py index 5221542..9453ab6 100644 --- a/tests/api/test_register_email.py +++ b/tests/api/test_register_email.py @@ -1,4 +1,3 @@ -import math import time from unittest.mock import patch @@ -43,11 +42,9 @@ def test_ok(self): data = { 'email': self.new_email, } - time_before = math.floor(time.time()) - with self.assert_one_mail_sent() as sent_emails: + with self.assert_one_mail_sent() as sent_emails, self.timer() as timer: response = self._test_authenticated(data) self.assert_valid_response(response, status.HTTP_200_OK) - time_after = math.ceil(time.time()) # Check database state. self.user.refresh_from_db() self.assertEqual(self.user.email, self.email) @@ -67,8 +64,8 @@ def test_ok(self): self.assertEqual(verification_data['email'], self.new_email) self.assertEqual(int(verification_data['user_id']), self.user.id) url_sig_timestamp = int(verification_data['timestamp']) - self.assertGreaterEqual(url_sig_timestamp, time_before) - self.assertLessEqual(url_sig_timestamp, time_after) + self.assertGreaterEqual(url_sig_timestamp, timer.start_time) + self.assertLessEqual(url_sig_timestamp, timer.end_time) signer = RegisterEmailSigner(verification_data) signer.verify() diff --git a/tests/utils.py b/tests/utils.py index 1c20f16..1b2b0d6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -89,8 +89,10 @@ def assert_no_mail_sent(self): def timer(self, get_current_timestamp=get_current_timestamp): timer = Timer(get_current_timestamp=get_current_timestamp) timer.set_start_time() - yield timer - timer.set_end_time() + try: + yield timer + finally: + timer.set_end_time() def _assert_urls_in_text(self, text, expected_num, line_url_pattern): lines = [line.rstrip() for line in text.split('\n')]