diff --git a/promo_code/business/pagination.py b/promo_code/business/pagination.py index 68eb65a..120f148 100644 --- a/promo_code/business/pagination.py +++ b/promo_code/business/pagination.py @@ -9,19 +9,15 @@ class CustomLimitOffsetPagination( max_limit = 100 def get_limit(self, request): - param_limit = request.query_params.get(self.limit_query_param) - if param_limit is not None: - limit = int(param_limit) + raw_limit = request.query_params.get(self.limit_query_param) - if limit == 0: - return 0 + if raw_limit is None: + return self.default_limit - if self.max_limit: - return min(limit, self.max_limit) + limit = int(raw_limit) - return limit - - return self.default_limit + # Allow 0, otherwise cut by max_limit + return 0 if limit == 0 else min(limit, self.max_limit) def get_paginated_response(self, data): response = rest_framework.response.Response(data) diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 5da9475..7ac7175 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -282,6 +282,123 @@ def to_representation(self, instance): return data +class PromoListQuerySerializer(rest_framework.serializers.Serializer): + """ + Serializer for validating query parameters of promo list requests. + """ + + limit = rest_framework.serializers.CharField( + required=False, + allow_blank=True, + ) + offset = rest_framework.serializers.CharField( + required=False, + allow_blank=True, + ) + sort_by = rest_framework.serializers.ChoiceField( + choices=['active_from', 'active_until'], + required=False, + ) + country = rest_framework.serializers.CharField( + required=False, + allow_blank=True, + ) + + _allowed_params = None + + def get_allowed_params(self): + if self._allowed_params is None: + self._allowed_params = set(self.fields.keys()) + return self._allowed_params + + def validate(self, attrs): + query_params = self.initial_data + allowed_params = self.get_allowed_params() + + unexpected_params = set(query_params.keys()) - allowed_params + if unexpected_params: + raise rest_framework.exceptions.ValidationError('Invalid params.') + + field_errors = {} + + attrs = self._validate_int_field('limit', attrs, field_errors) + attrs = self._validate_int_field('offset', attrs, field_errors) + + self._validate_country(query_params, attrs, field_errors) + + if field_errors: + raise rest_framework.exceptions.ValidationError(field_errors) + + return attrs + + def _validate_int_field(self, field_name, attrs, field_errors): + value_str = self.initial_data.get(field_name) + if value_str is None: + return attrs + + if value_str == '': + raise rest_framework.exceptions.ValidationError( + f'Invalid {field_name} format.', + ) + + try: + value_int = int(value_str) + if value_int < 0: + raise rest_framework.exceptions.ValidationError( + f'{field_name.capitalize()} cannot be negative.', + ) + attrs[field_name] = value_int + except (ValueError, TypeError): + raise rest_framework.exceptions.ValidationError( + f'Invalid {field_name} format.', + ) + + return attrs + + def _validate_country(self, query_params, attrs, field_errors): + countries_raw = query_params.getlist('country', []) + + if '' in countries_raw: + raise rest_framework.exceptions.ValidationError( + 'Invalid country format.', + ) + + country_codes = [] + invalid_codes = [] + + for country_group in countries_raw: + if not country_group.strip(): + continue + + parts = [part.strip() for part in country_group.split(',')] + + if '' in parts: + raise rest_framework.exceptions.ValidationError( + 'Invalid country format.', + ) + + country_codes.extend(parts) + + country_codes_upper = [c.upper() for c in country_codes] + + for code in country_codes_upper: + if len(code) != 2: + invalid_codes.append(code) + continue + try: + pycountry.countries.lookup(code) + except LookupError: + invalid_codes.append(code) + + if invalid_codes: + field_errors['country'] = ( + f'Invalid country codes: {", ".join(invalid_codes)}' + ) + + attrs['countries'] = country_codes + attrs.pop('country', None) + + class PromoReadOnlySerializer(rest_framework.serializers.ModelSerializer): promo_id = rest_framework.serializers.UUIDField( source='id', diff --git a/promo_code/business/views.py b/promo_code/business/views.py index a9b1cdc..24a47d3 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -1,14 +1,11 @@ import re import django.db.models -import pycountry -import rest_framework.exceptions import rest_framework.generics import rest_framework.permissions import rest_framework.response import rest_framework.serializers import rest_framework.status -import rest_framework.views import rest_framework_simplejwt.exceptions import rest_framework_simplejwt.tokens import rest_framework_simplejwt.views @@ -155,14 +152,21 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView): serializer_class = business.serializers.PromoReadOnlySerializer pagination_class = business.pagination.CustomLimitOffsetPagination + def initial(self, request, *args, **kwargs): + super().initial(request, *args, **kwargs) + + serializer = business.serializers.PromoListQuerySerializer( + data=request.query_params, + ) + serializer.is_valid(raise_exception=True) + request.validated_query_params = serializer.validated_data + def get_queryset(self): + params = self.request.validated_query_params + countries = [c.upper() for c in params.get('countries', [])] + sort_by = params.get('sort_by') + queryset = business.models.Promo.objects.for_company(self.request.user) - countries = [ - country.strip() - for group in self.request.query_params.getlist('country', []) - for country in group.split(',') - if country.strip() - ] if countries: regex_pattern = r'(' + '|'.join(map(re.escape, countries)) + ')' @@ -171,115 +175,9 @@ def get_queryset(self): | django.db.models.Q(target__country__isnull=True), ) - sort_by = self.request.query_params.get('sort_by') - if sort_by in ['active_from', 'active_until']: - queryset = queryset.order_by(f'-{sort_by}') - else: - queryset = queryset.order_by('-created_at') # noqa: R504 - - return queryset # noqa: R504 - - def list(self, request, *args, **kwargs): - try: - self.validate_query_params() - except rest_framework.exceptions.ValidationError as e: - return rest_framework.response.Response( - e.detail, - status=rest_framework.status.HTTP_400_BAD_REQUEST, - ) - - return super().list(request, *args, **kwargs) - - def validate_query_params(self): - self._validate_allowed_params() - errors = {} - self._validate_countries(errors) - self._validate_sort_by(errors) - self._validate_offset() - self._validate_limit() - if errors: - raise rest_framework.exceptions.ValidationError(errors) - - def _validate_allowed_params(self): - allowed_params = {'country', 'limit', 'offset', 'sort_by'} - unexpected_params = ( - set(self.request.query_params.keys()) - allowed_params - ) - - if unexpected_params: - raise rest_framework.exceptions.ValidationError('Invalid params.') - - def _validate_countries(self, errors): - countries = self.request.query_params.getlist('country', []) - country_list = [] - - for country_group in countries: - parts = [part.strip() for part in country_group.split(',')] - - if any(part == '' for part in parts): - raise rest_framework.exceptions.ValidationError( - 'Invalid country format.', - ) - - country_list.extend(parts) - - country_list = [c.strip().upper() for c in country_list if c.strip()] - - invalid_countries = [] - - for code in country_list: - if len(code) != 2: - invalid_countries.append(code) - continue - - try: - pycountry.countries.lookup(code) - except LookupError: - invalid_countries.append(code) - - if invalid_countries: - errors['country'] = ( - f'Invalid country codes: {", ".join(invalid_countries)}' - ) - - def _validate_sort_by(self, errors): - sort_by = self.request.query_params.get('sort_by') - if sort_by and sort_by not in ['active_from', 'active_until']: - errors['sort_by'] = ( - 'Invalid sort_by parameter. ' - 'Available values: active_from, active_until' - ) + ordering = f'-{sort_by}' if sort_by else '-created_at' - def _validate_offset(self): - offset = self.request.query_params.get('offset') - if offset is not None: - try: - offset = int(offset) - except (TypeError, ValueError): - raise rest_framework.exceptions.ValidationError( - 'Invalid offset format.', - ) - - if offset < 0: - raise rest_framework.exceptions.ValidationError( - 'Offset cannot be negative.', - ) - - def _validate_limit(self): - limit = self.request.query_params.get('limit') - - if limit is not None: - try: - limit = int(limit) - except (TypeError, ValueError): - raise rest_framework.exceptions.ValidationError( - 'Invalid limit format.', - ) - - if limit < 0: - raise rest_framework.exceptions.ValidationError( - 'Limit cannot be negative.', - ) + return queryset.order_by(ordering) class CompanyPromoDetailView(rest_framework.generics.RetrieveUpdateAPIView):