Skip to content

Commit

Permalink
Merge fb713a9 into e62f1f9
Browse files Browse the repository at this point in the history
  • Loading branch information
Soukaina committed Dec 14, 2017
2 parents e62f1f9 + fb713a9 commit 1fa4458
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 21 deletions.
1 change: 0 additions & 1 deletion rest_framework_expiring_authtoken/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from rest_framework_expiring_authtoken.settings import token_settings


class ExpiringToken(Token):

"""Extend Token to add an expired method."""
Expand Down
37 changes: 37 additions & 0 deletions rest_framework_expiring_authtoken/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework import serializers
from django.contrib.auth import authenticate


class EmailAuthTokenSerializer(serializers.Serializer):
email = serializers.CharField(label=_("Email"))
password = serializers.CharField(
label="Password",
style={'input_type': 'password'},
trim_whitespace=False
)

def validate(self, attrs):
email = attrs.get('email')
password = attrs.get('password')

if email and password:
user = authenticate(email=email, password=password)

if user:
# From Django 1.10 onwards the `authenticate` call simply
# returns `None` for is_active=False users.
# (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active:
msg = 'User account is disabled.'
raise serializers.ValidationError(msg)
else:
msg = 'Unable to log in with provided credentials.'
raise serializers.ValidationError(msg)
else:
msg = 'Must include "email" and "password".'
raise serializers.ValidationError(msg)

attrs['user'] = user
return attrs
11 changes: 6 additions & 5 deletions rest_framework_expiring_authtoken/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Classes:
ObtainExpiringAuthToken: View to provide tokens to clients.
"""
from rest_framework.authtoken.serializers import AuthTokenSerializer
from .serializers import EmailAuthTokenSerializer
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.response import Response
from rest_framework.status import HTTP_400_BAD_REQUEST
Expand All @@ -13,13 +13,13 @@

class ObtainExpiringAuthToken(ObtainAuthToken):

"""View enabling username/password exchange for expiring token."""
"""View enabling email/password exchange for expiring token."""

model = ExpiringToken

def post(self, request):
"""Respond to POSTed username/password with token."""
serializer = AuthTokenSerializer(data=request.data)
"""Respond to POSTed email/password with token."""
serializer = EmailAuthTokenSerializer(data=request.data)

if serializer.is_valid():
token, _ = ExpiringToken.objects.get_or_create(
Expand All @@ -33,7 +33,8 @@ def post(self, request):
user=serializer.validated_data['user']
)

data = {'token': token.key}
data = {'auth_token': token.key}

return Response(data)

return Response(serializer.errors, status=HTTP_400_BAD_REQUEST)
Expand Down
Empty file modified runtests.py
100644 → 100755
Empty file.
31 changes: 31 additions & 0 deletions tests/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from django.contrib.auth import get_user_model # gets the user_model django default or your own custom
from django.contrib.auth.backends import ModelBackend
from django.db.models import Q


class CustomAuthenticationBackend(ModelBackend): # requires to define two functions authenticate and get_user

def authenticate(self, email=None, password=None, **kwargs):
UserModel = get_user_model()

try:
# below line gives query set,you can change the queryset as per your requirement
user = UserModel.objects.filter(email__iexact=email).distinct()

except UserModel.DoesNotExist:
return None

if user.exists():
user_obj = user.first()
if user_obj.check_password(password):
return user_obj
return None
else:
return None

def get_user(self, user_id):
UserModel = get_user_model()
try:
return UserModel.objects.get(pk=user_id)
except UserModel.DoesNotExist:
return None
2 changes: 2 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
"tests",
]

AUTHENTICATION_BACKENDS = ( 'tests.backends.CustomAuthenticationBackend', )

DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
Expand Down
3 changes: 1 addition & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ class ExpiringTokenAuthenticationTestCase(TestCase):

def setUp(self):
"""Create a user and associated token."""
self.username = 'test'
self.email = 'test@test.com'
self.password = 'test'
self.user = User.objects.create_user(
username=self.username,
email=self.email,
username=self.email,
password=self.password
)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ class ExpiringTokenTestCase(TestCase):

def setUp(self):
"""Create a user and associated token."""
self.username = 'test'
self.email = 'test@test.com'
self.password = 'test'
self.user = User.objects.create_user(
username=self.username,
email=self.email,
username=self.email,
password=self.password
)

Expand Down
23 changes: 12 additions & 11 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ class ObtainExpiringTokenViewTestCase(APITestCase):

def setUp(self):
"""Create a user."""
self.username = 'test'
self.email = 'test@test.com'
self.password = 'test'
self.user = User.objects.create_user(
username=self.username,
username=self.email,
email=self.email,
password=self.password
password=self.password,
)

def test_post(self):
Expand All @@ -32,22 +31,24 @@ def test_post(self):
response = self.client.post(
'/obtain-token/',
{
'username': self.username,
'email': self.email,
'password': self.password
}
)

#import pdb;pdb.set_trace()

self.assertEqual(response.status_code, status.HTTP_200_OK)

# Check the response contains the token key.
self.assertEqual(token.key, response.data['token'])
self.assertEqual(token.key, response.data['auth_token'])

def test_post_create_token(self):
"""Check token is created if none exists."""
response = self.client.post(
'/obtain-token/',
{
'username': self.username,
'email': self.email,
'password': self.password
}
)
Expand All @@ -57,7 +58,7 @@ def test_post_create_token(self):
# Check token was created and the response contains the token key.
token = ExpiringToken.objects.first()
self.assertEqual(token.user, self.user)
self.assertEqual(response.data['token'], token.key)
self.assertEqual(response.data['auth_token'], token.key)

def test_post_no_credentials(self):
"""Check POST request with no credentials fails."""
Expand All @@ -66,7 +67,7 @@ def test_post_no_credentials(self):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data,
{
'username': ['This field is required.'],
'email': ['This field is required.'],
'password': ['This field is required.']
}
)
Expand All @@ -76,7 +77,7 @@ def test_post_wrong_credentials(self):
response = self.client.post(
'/obtain-token/',
{
'username': self.username,
'email': self.email,
'password': 'wrong'
}
)
Expand All @@ -101,7 +102,7 @@ def test_post_expired_token(self):
response = self.client.post(
'/obtain-token/',
{
'username': self.username,
'email': self.email,
'password': self.password
}
)
Expand All @@ -112,5 +113,5 @@ def test_post_expired_token(self):
token = ExpiringToken.objects.first()
key_2 = token.key
self.assertEqual(token.user, self.user)
self.assertEqual(response.data['token'], token.key)
self.assertEqual(response.data['auth_token'], token.key)
self.assertTrue(key_1 != key_2)

0 comments on commit 1fa4458

Please sign in to comment.