Skip to content

Commit

Permalink
Resolved issue #39
Browse files Browse the repository at this point in the history
Added VERIFICATION_URL_BUILDER setting
  • Loading branch information
apragacz committed Mar 31, 2019
1 parent 31d8960 commit e72d674
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 26 deletions.
14 changes: 14 additions & 0 deletions rest_registration/settings_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,20 @@ def __new__(cls, name, *, default=None, help=None, import_string=False):
default='rest_registration.utils.html.convert_html_to_text_preserving_urls', # noqa: E501
import_string=True,
),
Field(
'VERIFICATION_URL_BUILDER',
default='rest_registration.utils.verification.build_default_verification_url', # noqa: E501
import_string=True,
help=dedent("""\
The builder function receives the ``signer`` object and construct
the url using ``signer.get_base_url()``
and ``signer.get_signed_data()``. The default url builder will use
the base url and append the signed data as HTTP GET query string.
It is be solely up to the implementer of custom builder function
to encode the signed values properly in the URL.
"""),

),
]

CHANGE_PASSWORD_SETTINGS_FIELDS = [
Expand Down
11 changes: 11 additions & 0 deletions rest_registration/utils/verification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import urlencode

from django.core.signing import BadSignature, SignatureExpired

from rest_registration.exceptions import BadRequest
Expand All @@ -10,3 +12,12 @@ def verify_signer_or_bad_request(signer):
raise BadRequest('Signature expired')
except BadSignature:
raise BadRequest('Invalid signature')


def build_default_verification_url(signer):
base_url = signer.get_base_url()
params = urlencode(signer.get_signed_data())
url = '{base_url}?{params}'.format(base_url=base_url, params=params)
if signer.request:
url = signer.request.build_absolute_uri(url)
return url
11 changes: 4 additions & 7 deletions rest_registration/verification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pickle
import time
from urllib.parse import urlencode

from django.core.signing import BadSignature, SignatureExpired, Signer
from django.utils.crypto import constant_time_compare

from rest_registration.settings import registration_settings

PICKLE_REPR_PROTOCOL = 4


Expand Down Expand Up @@ -85,9 +86,5 @@ def __init__(self, data, request=None, strict=True):
self.request = request

def get_url(self):
base_url = self.get_base_url()
params = urlencode(self.get_signed_data())
url = '{base_url}?{params}'.format(base_url=base_url, params=params)
if self.request:
url = self.request.build_absolute_uri(url)
return url
url_builder = registration_settings.VERIFICATION_URL_BUILDER
return url_builder(self)
30 changes: 22 additions & 8 deletions tests/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,30 @@ def assert_response_is_not_found(self, response):
)

def assert_valid_verification_url(
self, url, expected_path=None, expected_query_keys=None):
parsed_url = urlparse(url)
self, url, expected_path=None, expected_fields=None,
url_parser=None):
if url_parser is None:
url_parser = self._parse_verification_url
try:
url_path, verification_data = url_parser(url, expected_fields)
except ValueError as e:
self.fail(str(e))
if expected_path is not None:
self.assertEqual(parsed_url.path, expected_path)
self.assertEqual(url_path, expected_path)
if expected_fields is not None:
self.assertSetEqual(
set(verification_data.keys()), set(expected_fields))
return verification_data

def _parse_verification_url(self, url, verification_field_names):
parsed_url = urlparse(url)
query = parse_qs(parsed_url.query, strict_parsing=True)
if expected_query_keys is not None:
self.assertSetEqual(set(query), set(expected_query_keys))

for values in query.values():
self.assert_len_equals(values, 1)
for key, values in query.items():
if len(values) == 0:
raise ValueError("no values for '{key}".format(key=key))
if len(values) > 1:
raise ValueError("multiple values for '{key}'".format(key=key))

verification_data = {key: values[0] for key, values in query.items()}
return verification_data
return parsed_url.path, verification_data
106 changes: 97 additions & 9 deletions tests/api/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import time
from unittest import mock
from unittest.mock import patch
from urllib.parse import quote_plus as urlquote
from urllib.parse import unquote_plus as urlunquote
from urllib.parse import urlparse

from django.test.utils import override_settings
from rest_framework import status

from rest_registration.api.views.register import RegisterSigner
from rest_registration.settings import registration_settings
from tests.utils import shallow_merge_dicts
from tests.utils import TestCase, shallow_merge_dicts

from .base import APIViewTestCase

Expand Down Expand Up @@ -36,10 +39,9 @@


@override_settings(REST_REGISTRATION=REST_REGISTRATION_WITH_VERIFICATION)
class RegisterViewTestCase(APIViewTestCase):
VIEW_NAME = 'register'
class RegisterSerializerTestCase(TestCase):

def test_register_serializer_ok(self):
def test_ok(self):
serializer_class = registration_settings.REGISTER_SERIALIZER_CLASS
serializer = serializer_class(data={})
field_names = {f for f in serializer.get_fields()}
Expand All @@ -56,7 +58,7 @@ def test_register_serializer_ok(self):
},
),
)
def test_register_serializer_no_password_ok(self):
def test_no_password_ok(self):
serializer_class = registration_settings.REGISTER_SERIALIZER_CLASS
serializer = serializer_class(data={})
field_names = {f for f in serializer.get_fields()}
Expand All @@ -65,6 +67,52 @@ def test_register_serializer_no_password_ok(self):
{'id', 'username', 'first_name', 'last_name', 'email', 'password'},
)


def build_custom_verification_url(signer):
base_url = signer.get_base_url()
signed_data = signer.get_signed_data()
if signer.USE_TIMESTAMP:
timestamp = signed_data.pop(signer.TIMESTAMP_FIELD)
else:
timestamp = None
signature = signed_data.pop(signer.SIGNATURE_FIELD)
segments = [signed_data[k] for k in sorted(signed_data.keys())]
segments.append(signature)
if timestamp:
segments.append(timestamp)
quoted_segments = [urlquote(str(s)) for s in segments]

url = base_url
if not url.endswith('/'):
url += '/'
url += '/'.join(quoted_segments)
url += '/'
if signer.request:
url = signer.request.build_absolute_uri(url)

return url


def parse_custom_verification_url(url, verification_field_names):
parsed_url = urlparse(url)
num_of_fields = len(verification_field_names)
url_path = parsed_url.path.rstrip('/')
url_segments = url_path.rsplit('/', num_of_fields)
if len(url_segments) != num_of_fields + 1:
raise ValueError("Could not parse {url}".format(url=url))

data_segments = url_segments[1:]
url_path = url_segments[0] + '/'
verification_data = {
name: urlunquote(value)
for name, value in zip(verification_field_names, data_segments)}
return url_path, verification_data


@override_settings(REST_REGISTRATION=REST_REGISTRATION_WITH_VERIFICATION)
class RegisterViewTestCase(APIViewTestCase):
VIEW_NAME = 'register'

def test_register_ok(self):
data = self._get_register_user_data(password='testpassword')
request = self.create_post_request(data)
Expand All @@ -88,7 +136,48 @@ def test_register_ok(self):
verification_data = self.assert_valid_verification_url(
url,
expected_path=REGISTER_VERIFICATION_URL,
expected_query_keys={'signature', 'user_id', 'timestamp'},
expected_fields={'signature', 'user_id', 'timestamp'},
)
url_user_id = int(verification_data['user_id'])
self.assertEqual(url_user_id, user_id)
url_sig_timestamp = int(verification_data['timestamp'])
self.assertGreaterEqual(url_sig_timestamp, time_before)
self.assertLessEqual(url_sig_timestamp, time_after)
signer = RegisterSigner(verification_data)
signer.verify()

@override_settings(
REST_REGISTRATION=shallow_merge_dicts(
REST_REGISTRATION_WITH_VERIFICATION, {
'VERIFICATION_URL_BUILDER': build_custom_verification_url,
},
),
)
def test_register_with_custom_verification_url_ok(self):
data = self._get_register_user_data(password='testpassword')
request = self.create_post_request(data)
time_before = math.floor(time.time())
with self.assert_one_mail_sent() as sent_emails:
response = self.view_func(request)
time_after = math.ceil(time.time())
self.assert_valid_response(response, status.HTTP_201_CREATED)
user_id = response.data['id']
# Check database state.
user = self.user_class.objects.get(id=user_id)
self.assertEqual(user.username, data['username'])
self.assertTrue(user.check_password(data['password']))
self.assertFalse(user.is_active)
# Check verification e-mail.
sent_email = sent_emails[0]
self.assertEqual(sent_email.from_email, VERIFICATION_FROM_EMAIL)
self.assertListEqual(sent_email.to, [data['email']])
url = self.assert_one_url_line_in_text(sent_email.body)

verification_data = self.assert_valid_verification_url(
url,
expected_path=REGISTER_VERIFICATION_URL,
expected_fields=['user_id', 'signature', 'timestamp'],
url_parser=parse_custom_verification_url,
)
url_user_id = int(verification_data['user_id'])
self.assertEqual(url_user_id, user_id)
Expand All @@ -98,7 +187,6 @@ def test_register_ok(self):
signer = RegisterSigner(verification_data)
signer.verify()

# TODO: unskip this test when &times entity problem will be fixed.
@override_settings(
REST_REGISTRATION=REST_REGISTRATION_WITH_HTML_EMAIL_VERIFICATION,
)
Expand All @@ -125,7 +213,7 @@ def test_register_with_html_email_ok(self):
verification_data = self.assert_valid_verification_url(
url,
expected_path=REGISTER_VERIFICATION_URL,
expected_query_keys={'signature', 'user_id', 'timestamp'},
expected_fields={'signature', 'user_id', 'timestamp'},
)
url_user_id = int(verification_data['user_id'])
self.assertEqual(url_user_id, user_id)
Expand Down Expand Up @@ -166,7 +254,7 @@ def test_register_no_password_confirm_ok(self):
verification_data = self.assert_valid_verification_url(
url,
expected_path=REGISTER_VERIFICATION_URL,
expected_query_keys={'signature', 'user_id', 'timestamp'},
expected_fields={'signature', 'user_id', 'timestamp'},
)
url_user_id = int(verification_data['user_id'])
self.assertEqual(url_user_id, user_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_register_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_ok(self):
verification_data = self.assert_valid_verification_url(
url,
expected_path=REGISTER_EMAIL_VERIFICATION_URL,
expected_query_keys={'signature', 'user_id', 'timestamp', 'email'},
expected_fields={'signature', 'user_id', 'timestamp', 'email'},
)
self.assertEqual(verification_data['email'], self.new_email)
self.assertEqual(int(verification_data['user_id']), self.user.id)
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_reset_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _assert_valid_send_link_email(
verification_data = self.assert_valid_verification_url(
url,
expected_path=RESET_PASSWORD_VERIFICATION_URL,
expected_query_keys={'signature', 'user_id', 'timestamp'},
expected_fields={'signature', 'user_id', 'timestamp'},
)
self.assertEqual(int(verification_data['user_id']), user.id)
url_sig_timestamp = int(verification_data['timestamp'])
Expand Down

0 comments on commit e72d674

Please sign in to comment.