Skip to content

refactor(core, business, serializers): centralize base serializers and streamline promo/user logic #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions promo_code/business/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def create_company(self, email, name, password=None, **extra_fields):


class PromoManager(django.db.models.Manager):
with_related_fields = (
'id',
'company__id',
'company__name',
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'created_at',
)

def get_queryset(self):
return super().get_queryset()

Expand All @@ -30,19 +45,7 @@ def with_related(self):
self.select_related('company')
.prefetch_related('unique_codes')
.only(
'id',
'company',
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'created_at',
'company__id',
'company__name',
*self.with_related_fields,
)
)

Expand Down
264 changes: 11 additions & 253 deletions promo_code/business/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import django.contrib.auth.password_validation
import django.db.transaction
import pycountry
import rest_framework.exceptions
import rest_framework.serializers
import rest_framework_simplejwt.exceptions
Expand Down Expand Up @@ -140,36 +139,14 @@ def get_active_company_from_token(self, token):
return company


class CountryField(rest_framework.serializers.CharField):
"""
Custom field for validating country codes according to ISO 3166-1 alpha-2.
"""

def __init__(self, **kwargs):
kwargs['allow_blank'] = False
kwargs['min_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH
kwargs['max_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH
super().__init__(**kwargs)

def to_internal_value(self, data):
code = super().to_internal_value(data)
try:
pycountry.countries.lookup(code.upper())
except LookupError:
raise rest_framework.serializers.ValidationError(
'Invalid ISO 3166-1 alpha-2 country code.',
)
return code


class MultiCountryField(rest_framework.serializers.ListField):
"""
Custom field for handling multiple country codes,
passed either as a comma-separated list or as multiple parameters.
"""

def __init__(self, **kwargs):
kwargs['child'] = CountryField()
kwargs['child'] = core.serializers.CountryField()
kwargs['allow_empty'] = False
super().__init__(**kwargs)

Expand All @@ -196,234 +173,16 @@ def to_internal_value(self, data):
return super().to_internal_value(data)


class TargetSerializer(rest_framework.serializers.Serializer):
age_from = rest_framework.serializers.IntegerField(
min_value=business.constants.TARGET_AGE_MIN,
max_value=business.constants.TARGET_AGE_MAX,
required=False,
)
age_until = rest_framework.serializers.IntegerField(
min_value=business.constants.TARGET_AGE_MIN,
max_value=business.constants.TARGET_AGE_MAX,
required=False,
)
country = CountryField(required=False)

categories = rest_framework.serializers.ListField(
child=rest_framework.serializers.CharField(
min_length=business.constants.TARGET_CATEGORY_MIN_LENGTH,
max_length=business.constants.TARGET_CATEGORY_MAX_LENGTH,
allow_blank=False,
),
max_length=business.constants.TARGET_CATEGORY_MAX_ITEMS,
required=False,
allow_empty=True,
)

def validate(self, data):
age_from = data.get('age_from')
age_until = data.get('age_until')

if (
age_from is not None
and age_until is not None
and age_from > age_until
):
raise rest_framework.serializers.ValidationError(
{'age_until': 'Must be greater than or equal to age_from.'},
)
return data


class BasePromoSerializer(rest_framework.serializers.ModelSerializer):
"""
Base serializer for promo, containing validation and representation logic.
"""

image_url = rest_framework.serializers.URLField(
required=False,
allow_blank=False,
max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH,
)
description = rest_framework.serializers.CharField(
min_length=business.constants.PROMO_DESC_MIN_LENGTH,
max_length=business.constants.PROMO_DESC_MAX_LENGTH,
required=True,
)
target = TargetSerializer(required=True, allow_null=True)
promo_common = rest_framework.serializers.CharField(
min_length=business.constants.PROMO_COMMON_CODE_MIN_LENGTH,
max_length=business.constants.PROMO_COMMON_CODE_MAX_LENGTH,
required=False,
allow_null=True,
allow_blank=False,
)
promo_unique = rest_framework.serializers.ListField(
child=rest_framework.serializers.CharField(
min_length=business.constants.PROMO_UNIQUE_CODE_MIN_LENGTH,
max_length=business.constants.PROMO_UNIQUE_CODE_MAX_LENGTH,
allow_blank=False,
),
min_length=business.constants.PROMO_UNIQUE_LIST_MIN_ITEMS,
max_length=business.constants.PROMO_UNIQUE_LIST_MAX_ITEMS,
required=False,
allow_null=True,
)

class Meta:
model = business.models.Promo
fields = (
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'promo_unique',
)

def validate(self, data):
"""
Main validation method.
Determines the mode and calls the corresponding validation method.
"""

mode = data.get('mode', getattr(self.instance, 'mode', None))

if mode == business.constants.PROMO_MODE_COMMON:
self._validate_common(data)
elif mode == business.constants.PROMO_MODE_UNIQUE:
self._validate_unique(data)
elif mode is None:
raise rest_framework.serializers.ValidationError(
{'mode': 'This field is required.'},
)
else:
raise rest_framework.serializers.ValidationError(
{'mode': 'Invalid mode.'},
)

return data

def _validate_common(self, data):
"""
Validations for COMMON promo mode.
"""

if 'promo_unique' in data and data['promo_unique'] is not None:
raise rest_framework.serializers.ValidationError(
{'promo_unique': 'This field is not allowed for COMMON mode.'},
)

if self.instance is None and not data.get('promo_common'):
raise rest_framework.serializers.ValidationError(
{'promo_common': 'This field is required for COMMON mode.'},
)

new_max_count = data.get('max_count')
if self.instance and new_max_count is not None:
used_count = self.instance.get_used_codes_count
if used_count > new_max_count:
raise rest_framework.serializers.ValidationError(
{
'max_count': (
f'max_count ({new_max_count}) cannot be less than '
f'used_count ({used_count}).'
),
},
)

effective_max_count = (
new_max_count
if new_max_count is not None
else getattr(self.instance, 'max_count', None)
)

min_c = business.constants.PROMO_COMMON_MIN_COUNT
max_c = business.constants.PROMO_COMMON_MAX_COUNT
if effective_max_count is not None and not (
min_c <= effective_max_count <= max_c
):
raise rest_framework.serializers.ValidationError(
{
'max_count': (
f'Must be between {min_c} and {max_c} for COMMON mode.'
),
},
)

def _validate_unique(self, data):
"""
Validations for UNIQUE promo mode.
"""

if 'promo_common' in data and data['promo_common'] is not None:
raise rest_framework.serializers.ValidationError(
{'promo_common': 'This field is not allowed for UNIQUE mode.'},
)

if self.instance is None and not data.get('promo_unique'):
raise rest_framework.serializers.ValidationError(
{'promo_unique': 'This field is required for UNIQUE mode.'},
)

effective_max_count = data.get(
'max_count',
getattr(self.instance, 'max_count', None),
)

if (
effective_max_count is not None
and effective_max_count
!= business.constants.PROMO_UNIQUE_MAX_COUNT
):
raise rest_framework.serializers.ValidationError(
{
'max_count': (
'Must be equal to '
f'{business.constants.PROMO_UNIQUE_MAX_COUNT} '
'for UNIQUE mode.'
),
},
)

def to_representation(self, instance):
"""
Controls the display of fields in the response.
"""

data = super().to_representation(instance)

if not instance.image_url:
data.pop('image_url', None)

if instance.mode == business.constants.PROMO_MODE_UNIQUE:
data.pop('promo_common', None)
if 'promo_unique' in self.fields and isinstance(
self.fields['promo_unique'],
rest_framework.serializers.SerializerMethodField,
):
data['promo_unique'] = self.get_promo_unique(instance)
else:
data['promo_unique'] = [
code.code for code in instance.unique_codes.all()
]
else:
data.pop('promo_unique', None)

return data


class PromoCreateSerializer(BasePromoSerializer):
class PromoCreateSerializer(core.serializers.BaseCompanyPromoSerializer):
url = rest_framework.serializers.HyperlinkedIdentityField(
view_name='api-business:promo-detail',
lookup_field='id',
)

class Meta(BasePromoSerializer.Meta):
fields = ('url',) + BasePromoSerializer.Meta.fields
class Meta(core.serializers.BaseCompanyPromoSerializer.Meta):
fields = (
'url',
) + core.serializers.BaseCompanyPromoSerializer.Meta.fields

def create(self, validated_data):
target_data = validated_data.pop('target')
Expand Down Expand Up @@ -468,7 +227,7 @@ def validate(self, attrs):
return attrs


class PromoDetailSerializer(BasePromoSerializer):
class PromoDetailSerializer(core.serializers.BaseCompanyPromoSerializer):
promo_id = rest_framework.serializers.UUIDField(
source='id',
read_only=True,
Expand Down Expand Up @@ -496,8 +255,8 @@ class PromoDetailSerializer(BasePromoSerializer):

promo_unique = rest_framework.serializers.SerializerMethodField()

class Meta(BasePromoSerializer.Meta):
fields = BasePromoSerializer.Meta.fields + (
class Meta(core.serializers.BaseCompanyPromoSerializer.Meta):
fields = core.serializers.BaseCompanyPromoSerializer.Meta.fields + (
'promo_id',
'company_name',
'like_count',
Expand All @@ -514,13 +273,12 @@ def get_promo_unique(self, obj):
def update(self, instance, validated_data):
target_data = validated_data.pop('target', None)

for attr, value in validated_data.items():
setattr(instance, attr, value)
instance = super().update(instance, validated_data)

if target_data is not None:
instance.target = target_data
instance.save(update_fields=['target'])

instance.save()
return instance


Expand Down
13 changes: 3 additions & 10 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,21 @@ class CompanyPromoListCreateView(rest_framework.generics.ListCreateAPIView):
rest_framework.permissions.IsAuthenticated,
business.permissions.IsCompanyUser,
]
# Pagination is only needed for GET (listing)
pagination_class = core.pagination.CustomLimitOffsetPagination

_validated_query_params = {}

def get_serializer_class(self):
if self.request.method == 'POST':
return business.serializers.PromoCreateSerializer

return business.serializers.PromoReadOnlySerializer

def list(self, request, *args, **kwargs):
def get_queryset(self):
query_serializer = business.serializers.PromoListQuerySerializer(
data=request.query_params,
data=self.request.query_params,
)
query_serializer.is_valid(raise_exception=True)
self._validated_query_params = query_serializer.validated_data
params = query_serializer.validated_data

return super().list(request, *args, **kwargs)

def get_queryset(self):
params = self._validated_query_params
countries = [c.upper() for c in params.get('countries', [])]
sort_by = params.get('sort_by')

Expand Down
Loading