Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions promo_code/business/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,15 @@ class CustomLimitOffsetPagination(
max_limit = 100

def get_limit(self, request):
param_limit = request.query_params.get(self.limit_query_param)
if param_limit is not None:
limit = int(param_limit)
raw_limit = request.query_params.get(self.limit_query_param)

if limit == 0:
return 0
if raw_limit is None:
return self.default_limit

if self.max_limit:
return min(limit, self.max_limit)
limit = int(raw_limit)

return limit

return self.default_limit
# Allow 0, otherwise cut by max_limit
return 0 if limit == 0 else min(limit, self.max_limit)

def get_paginated_response(self, data):
response = rest_framework.response.Response(data)
Expand Down
117 changes: 117 additions & 0 deletions promo_code/business/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,123 @@ def to_representation(self, instance):
return data


class PromoListQuerySerializer(rest_framework.serializers.Serializer):
"""
Serializer for validating query parameters of promo list requests.
"""

limit = rest_framework.serializers.CharField(
required=False,
allow_blank=True,
)
offset = rest_framework.serializers.CharField(
required=False,
allow_blank=True,
)
sort_by = rest_framework.serializers.ChoiceField(
choices=['active_from', 'active_until'],
required=False,
)
country = rest_framework.serializers.CharField(
required=False,
allow_blank=True,
)

_allowed_params = None

def get_allowed_params(self):
if self._allowed_params is None:
self._allowed_params = set(self.fields.keys())
return self._allowed_params

def validate(self, attrs):
query_params = self.initial_data
allowed_params = self.get_allowed_params()

unexpected_params = set(query_params.keys()) - allowed_params
if unexpected_params:
raise rest_framework.exceptions.ValidationError('Invalid params.')

field_errors = {}

attrs = self._validate_int_field('limit', attrs, field_errors)
attrs = self._validate_int_field('offset', attrs, field_errors)

self._validate_country(query_params, attrs, field_errors)

if field_errors:
raise rest_framework.exceptions.ValidationError(field_errors)

return attrs

def _validate_int_field(self, field_name, attrs, field_errors):
value_str = self.initial_data.get(field_name)
if value_str is None:
return attrs

if value_str == '':
raise rest_framework.exceptions.ValidationError(
f'Invalid {field_name} format.',
)

try:
value_int = int(value_str)
if value_int < 0:
raise rest_framework.exceptions.ValidationError(
f'{field_name.capitalize()} cannot be negative.',
)
attrs[field_name] = value_int
except (ValueError, TypeError):
raise rest_framework.exceptions.ValidationError(
f'Invalid {field_name} format.',
)

return attrs

def _validate_country(self, query_params, attrs, field_errors):
countries_raw = query_params.getlist('country', [])

if '' in countries_raw:
raise rest_framework.exceptions.ValidationError(
'Invalid country format.',
)

country_codes = []
invalid_codes = []

for country_group in countries_raw:
if not country_group.strip():
continue

parts = [part.strip() for part in country_group.split(',')]

if '' in parts:
raise rest_framework.exceptions.ValidationError(
'Invalid country format.',
)

country_codes.extend(parts)

country_codes_upper = [c.upper() for c in country_codes]

for code in country_codes_upper:
if len(code) != 2:
invalid_codes.append(code)
continue
try:
pycountry.countries.lookup(code)
except LookupError:
invalid_codes.append(code)

if invalid_codes:
field_errors['country'] = (
f'Invalid country codes: {", ".join(invalid_codes)}'
)

attrs['countries'] = country_codes
attrs.pop('country', None)


class PromoReadOnlySerializer(rest_framework.serializers.ModelSerializer):
promo_id = rest_framework.serializers.UUIDField(
source='id',
Expand Down
132 changes: 15 additions & 117 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import re

import django.db.models
import pycountry
import rest_framework.exceptions
import rest_framework.generics
import rest_framework.permissions
import rest_framework.response
import rest_framework.serializers
import rest_framework.status
import rest_framework.views
import rest_framework_simplejwt.exceptions
import rest_framework_simplejwt.tokens
import rest_framework_simplejwt.views
Expand Down Expand Up @@ -155,14 +152,21 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView):
serializer_class = business.serializers.PromoReadOnlySerializer
pagination_class = business.pagination.CustomLimitOffsetPagination

def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)

serializer = business.serializers.PromoListQuerySerializer(
data=request.query_params,
)
serializer.is_valid(raise_exception=True)
request.validated_query_params = serializer.validated_data

def get_queryset(self):
params = self.request.validated_query_params
countries = [c.upper() for c in params.get('countries', [])]
sort_by = params.get('sort_by')

queryset = business.models.Promo.objects.for_company(self.request.user)
countries = [
country.strip()
for group in self.request.query_params.getlist('country', [])
for country in group.split(',')
if country.strip()
]

if countries:
regex_pattern = r'(' + '|'.join(map(re.escape, countries)) + ')'
Expand All @@ -171,115 +175,9 @@ def get_queryset(self):
| django.db.models.Q(target__country__isnull=True),
)

sort_by = self.request.query_params.get('sort_by')
if sort_by in ['active_from', 'active_until']:
queryset = queryset.order_by(f'-{sort_by}')
else:
queryset = queryset.order_by('-created_at') # noqa: R504

return queryset # noqa: R504

def list(self, request, *args, **kwargs):
try:
self.validate_query_params()
except rest_framework.exceptions.ValidationError as e:
return rest_framework.response.Response(
e.detail,
status=rest_framework.status.HTTP_400_BAD_REQUEST,
)

return super().list(request, *args, **kwargs)

def validate_query_params(self):
self._validate_allowed_params()
errors = {}
self._validate_countries(errors)
self._validate_sort_by(errors)
self._validate_offset()
self._validate_limit()
if errors:
raise rest_framework.exceptions.ValidationError(errors)

def _validate_allowed_params(self):
allowed_params = {'country', 'limit', 'offset', 'sort_by'}
unexpected_params = (
set(self.request.query_params.keys()) - allowed_params
)

if unexpected_params:
raise rest_framework.exceptions.ValidationError('Invalid params.')

def _validate_countries(self, errors):
countries = self.request.query_params.getlist('country', [])
country_list = []

for country_group in countries:
parts = [part.strip() for part in country_group.split(',')]

if any(part == '' for part in parts):
raise rest_framework.exceptions.ValidationError(
'Invalid country format.',
)

country_list.extend(parts)

country_list = [c.strip().upper() for c in country_list if c.strip()]

invalid_countries = []

for code in country_list:
if len(code) != 2:
invalid_countries.append(code)
continue

try:
pycountry.countries.lookup(code)
except LookupError:
invalid_countries.append(code)

if invalid_countries:
errors['country'] = (
f'Invalid country codes: {", ".join(invalid_countries)}'
)

def _validate_sort_by(self, errors):
sort_by = self.request.query_params.get('sort_by')
if sort_by and sort_by not in ['active_from', 'active_until']:
errors['sort_by'] = (
'Invalid sort_by parameter. '
'Available values: active_from, active_until'
)
ordering = f'-{sort_by}' if sort_by else '-created_at'

def _validate_offset(self):
offset = self.request.query_params.get('offset')
if offset is not None:
try:
offset = int(offset)
except (TypeError, ValueError):
raise rest_framework.exceptions.ValidationError(
'Invalid offset format.',
)

if offset < 0:
raise rest_framework.exceptions.ValidationError(
'Offset cannot be negative.',
)

def _validate_limit(self):
limit = self.request.query_params.get('limit')

if limit is not None:
try:
limit = int(limit)
except (TypeError, ValueError):
raise rest_framework.exceptions.ValidationError(
'Invalid limit format.',
)

if limit < 0:
raise rest_framework.exceptions.ValidationError(
'Limit cannot be negative.',
)
return queryset.order_by(ordering)


class CompanyPromoDetailView(rest_framework.generics.RetrieveUpdateAPIView):
Expand Down