Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions promo_code/business/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import rest_framework.exceptions
import rest_framework.serializers
import rest_framework.status
import rest_framework_simplejwt.exceptions
import rest_framework_simplejwt.serializers
import rest_framework_simplejwt.tokens
import rest_framework_simplejwt.views


class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
Expand Down Expand Up @@ -90,3 +94,39 @@ def validate(self, attrs):
)

return attrs


class CompanyTokenRefreshSerializer(
rest_framework_simplejwt.serializers.TokenRefreshSerializer,
):
def validate(self, attrs):
refresh = rest_framework_simplejwt.tokens.RefreshToken(
attrs['refresh'],
)
user_type = refresh.payload.get('user_type', 'user')

if user_type != 'company':
raise rest_framework_simplejwt.exceptions.InvalidToken(
'This refresh endpoint is for company tokens only',
)

company_id = refresh.payload.get('company_id')
if not company_id:
raise rest_framework_simplejwt.exceptions.InvalidToken(
'Company ID missing in token',
)

try:
company = business_models.Company.objects.get(id=company_id)
except business_models.Company.DoesNotExist:
raise rest_framework_simplejwt.exceptions.InvalidToken(
'Company not found',
)

token_version = refresh.payload.get('token_version', 0)
if company.token_version != token_version:
raise rest_framework_simplejwt.exceptions.InvalidToken(
'Token is blacklisted',
)

return super().validate(attrs)
5 changes: 4 additions & 1 deletion promo_code/business/tests/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ class BaseBusinessAuthTestCase(rest_framework.test.APITestCase):
def setUpTestData(cls):
super().setUpTestData()
cls.client = rest_framework.test.APIClient()
cls.company_refresh_url = django.urls.reverse(
'api-business:company-token-refresh',
)
cls.protected_url = django.urls.reverse('api-core:protected')
cls.signup_url = django.urls.reverse('api-business:company-sign-up')
cls.signin_url = django.urls.reverse('api-business:company-sign-in')
cls.protected_url = django.urls.reverse('api-core:protected')
cls.valid_data = {
'name': 'Digital Marketing Solutions Inc.',
'email': 'testcompany@example.com',
Expand Down
193 changes: 193 additions & 0 deletions promo_code/business/tests/auth/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import business.tests.auth.base
import rest_framework.status
import rest_framework.test
import rest_framework_simplejwt.tokens

import user.models


class JWTTests(business.tests.auth.base.BaseBusinessAuthTestCase):
Expand Down Expand Up @@ -82,3 +85,193 @@ def test_registration_token_invalid_after_login(self):
response.status_code,
rest_framework.status.HTTP_200_OK,
)


class TestCompanyTokenRefresh(
business.tests.auth.base.BaseBusinessAuthTestCase,
):
def setUp(self):
super().setUp()

self.company = business.models.Company.objects.create_company(
name='Digital Marketing Solutions Inc.',
email='testcompany@example.com',
password='SuperStrongPassword2000!',
token_version=1,
)

self.company_data = {
'email': 'testcompany@example.com',
'password': 'SuperStrongPassword2000!',
}

self.company_refresh = rest_framework_simplejwt.tokens.RefreshToken()
self.company_refresh.payload.update(
{
'user_type': 'company',
'company_id': self.company.id,
'token_version': self.company.token_version,
},
)

self.user = user.models.User.objects.create_user(
email='minecraft.digger@gmail.com',
name='Steve',
surname='Jobs',
password='SuperStrongPassword2000!',
other={'age': 23, 'country': 'gb'},
)
self.user_refresh = (
rest_framework_simplejwt.tokens.RefreshToken.for_user(self.user)
)
self.user_refresh.payload['user_type'] = 'user'

def test_successful_company_token_refresh(self):
response = self.client.post(
self.company_refresh_url,
{'refresh': str(self.company_refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_200_OK,
)
self.assertIn('access', response.data)
self.assertIn('refresh', response.data)

self.assertNotEqual(self.company_refresh, response.data['refresh'])

def test_reject_user_tokens(self):
response = self.client.post(
self.company_refresh_url,
{'refresh': str(self.user_refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertIn(
'This refresh endpoint is for company tokens only',
str(response.content),
)

def test_token_version_mismatch(self):
self.company.token_version = 2
self.company.save()

response = self.client.post(
self.company_refresh_url,
{'refresh': str(self.company_refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertIn('Token is blacklisted', str(response.content))

def test_missing_company_id(self):
invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken()
invalid_refresh.payload.update(
{'user_type': 'company', 'token_version': 1},
)

response = self.client.post(
self.company_refresh_url,
{'refresh': str(invalid_refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertIn(
'Company ID missing in token',
str(response.content.decode()),
)

def test_company_not_found(self):
invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken()
invalid_refresh.payload.update(
{'user_type': 'company', 'company_id': 999, 'token_version': 1},
)

response = self.client.post(
self.company_refresh_url,
{'refresh': str(invalid_refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertIn('Company not found', str(response.content))

def test_refresh_token_invalidation_after_new_login(self):
first_login_response = self.client.post(
self.signin_url,
self.company_data,
format='json',
)
refresh_token_v1 = first_login_response.data['refresh']

second_login_response = self.client.post(
self.signin_url,
self.company_data,
format='json',
)
refresh_token_v2 = second_login_response.data['refresh']

refresh_response_v1 = self.client.post(
self.company_refresh_url,
{'refresh': refresh_token_v1},
format='json',
)
self.assertEqual(
refresh_response_v1.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertEqual(refresh_response_v1.data['code'], 'token_not_valid')
self.assertEqual(
str(refresh_response_v1.data['detail']),
'Token is blacklisted',
)

refresh_response_v2 = self.client.post(
self.company_refresh_url,
{'refresh': refresh_token_v2},
format='json',
)
self.assertEqual(
refresh_response_v2.status_code,
rest_framework.status.HTTP_200_OK,
)
self.assertIn('access', refresh_response_v2.data)

self.client.credentials(
HTTP_AUTHORIZATION='Bearer ' + first_login_response.data['access'],
)
protected_response = self.client.get(self.protected_url)
self.assertEqual(
protected_response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)

def test_default_user_type_handling(self):
refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(
self.user,
)
response = self.client.post(
self.company_refresh_url,
{'refresh': str(refresh)},
)

self.assertEqual(
response.status_code,
rest_framework.status.HTTP_401_UNAUTHORIZED,
)
self.assertIn(
'This refresh endpoint is for company tokens only',
str(response.content),
)
5 changes: 5 additions & 0 deletions promo_code/business/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@
business.views.CompanySignInView.as_view(),
name='company-sign-in',
),
django.urls.path(
'token/refresh',
business.views.CompanyTokenRefreshView.as_view(),
name='company-token-refresh',
),
]
4 changes: 4 additions & 0 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ def post(self, request):
response_data,
status=rest_framework.status.HTTP_200_OK,
)


class CompanyTokenRefreshView(rest_framework_simplejwt.views.TokenRefreshView):
serializer_class = business.serializers.CompanyTokenRefreshSerializer
2 changes: 1 addition & 1 deletion promo_code/user/tests/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def setUpTestData(cls):
super().setUpTestData()
cls.client = rest_framework.test.APIClient()
cls.protected_url = django.urls.reverse('api-core:protected')
cls.refresh_url = django.urls.reverse('api-user:token_refresh')
cls.refresh_url = django.urls.reverse('api-user:user-token-refresh')
cls.signup_url = django.urls.reverse('api-user:sign-up')
cls.signin_url = django.urls.reverse('api-user:sign-in')

Expand Down
2 changes: 1 addition & 1 deletion promo_code/user/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
django.urls.path(
'token/refresh/',
rest_framework_simplejwt.views.TokenRefreshView.as_view(),
name='token_refresh',
name='user-token-refresh',
),
]